diff --git a/notebooks/AmazonBeautyDatasetStatistics.ipynb b/notebooks/AmazonBeautyDatasetStatistics.ipynb
index 6d34ff2..379239d 100644
--- a/notebooks/AmazonBeautyDatasetStatistics.ipynb
+++ b/notebooks/AmazonBeautyDatasetStatistics.ipynb
@@ -405,7 +405,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": ".venv",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -419,7 +419,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.6"
+ "version": "3.10.12"
}
},
"nbformat": 4,
diff --git a/notebooks/LsvdDownload.ipynb b/notebooks/LsvdDownload.ipynb
new file mode 100644
index 0000000..c57e1af
--- /dev/null
+++ b/notebooks/LsvdDownload.ipynb
@@ -0,0 +1,574 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "SbkKok0dfjjS"
+ },
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import polars as pl"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'1.8.2'"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pl.__version__"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "HF_ENDPOINT=\"http://huggingface.proxy\" hf download deepvk/VK-LSVD --repo-type dataset --include \"metadata/*\" --local-dir /home/jovyan/IRec/sigir/lsvd_data/raw\n",
+ "\n",
+ "HF_ENDPOINT=\"http://huggingface.proxy\" hf download deepvk/VK-LSVD --repo-type dataset --include \"subsamples/ur0.01_ir0.01/*\" --local-dir /home/jovyan/IRec/sigir/lsvd_data/raw\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Разбиение сабсэмплов на базовую, гэп, вал и тест части\n",
+ "\n",
+ "Добавляется колонка original_order чтобы сохранять порядок внутри каждой из частей"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(8, 1, 1, 1, 11, 9)"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "subsample_name = 'ur0.01_ir0.01'\n",
+ "content_embedding_size = 256\n",
+ "DATASET_PATH = \"/home/jovyan/IRec/sigir/lsvd_data/raw\"\n",
+ "\n",
+ "metadata_files = ['metadata/users_metadata.parquet',\n",
+ " 'metadata/items_metadata.parquet',\n",
+ " 'metadata/item_embeddings.npz']\n",
+ "\n",
+ "BASE_WEEKS = (15, 23)\n",
+ "GAP_WEEKS = (23, 24) #увеличить гэп\n",
+ "VAL_WEEKS = (24, 25)\n",
+ "\n",
+ "base_interactions_files = [f'subsamples/{subsample_name}/train/week_{i:02}.parquet'\n",
+ " for i in range(BASE_WEEKS[0], BASE_WEEKS[1])]\n",
+ "\n",
+ "gap_interactions_files = [f'subsamples/{subsample_name}/train/week_{i:02}.parquet'\n",
+ " for i in range(GAP_WEEKS[0], GAP_WEEKS[1])]\n",
+ "\n",
+ "val_interactions_files = [f'subsamples/{subsample_name}/train/week_{i:02}.parquet'\n",
+ " for i in range(VAL_WEEKS[0], VAL_WEEKS[1])]\n",
+ "\n",
+ "test_interactions_files = [f'subsamples/{subsample_name}/validation/week_25.parquet']\n",
+ "\n",
+ "all_interactions_files = base_interactions_files + gap_interactions_files + val_interactions_files + test_interactions_files\n",
+ "\n",
+ "base_with_gap_interactions_files = base_interactions_files + gap_interactions_files\n",
+ "\n",
+ "len(base_interactions_files), len(gap_interactions_files), len(val_interactions_files), len(test_interactions_files), len(all_interactions_files), len(base_with_gap_interactions_files)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_parquet_interactions(data_files):\n",
+ " data_interactions = pl.concat([pl.scan_parquet(f'{DATASET_PATH}/{file}')\n",
+ " for file in data_files])\n",
+ " data_interactions = data_interactions.collect(streaming=True)\n",
+ " data_interactions = data_interactions.with_row_index(\"original_order\")\n",
+ " return data_interactions\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "base_interactions = get_parquet_interactions(base_interactions_files)\n",
+ "gap_interactions = get_parquet_interactions(gap_interactions_files)\n",
+ "val_interactions = get_parquet_interactions(val_interactions_files)\n",
+ "test_interactions = get_parquet_interactions(test_interactions_files)\n",
+ "all_data_interactions = get_parquet_interactions(all_interactions_files)\n",
+ "base_with_gap_interactions = get_parquet_interactions(base_with_gap_interactions_files)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Загрузка и фильтрация эмбеддингов"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "all_data_users = all_data_interactions.select('user_id').unique()\n",
+ "all_data_items = all_data_interactions.select('item_id').unique()\n",
+ "\n",
+ "item_ids = np.load(f\"{DATASET_PATH}/metadata/item_embeddings.npz\")['item_id']\n",
+ "item_embeddings = np.load(f\"{DATASET_PATH}/metadata/item_embeddings.npz\")['embedding']\n",
+ "\n",
+ "mask = np.isin(item_ids, all_data_items.to_numpy())\n",
+ "item_ids = item_ids[mask]\n",
+ "item_embeddings = item_embeddings[mask]\n",
+ "item_embeddings = item_embeddings[:, :content_embedding_size]\n",
+ "\n",
+ "users_metadata = pl.read_parquet(f\"{DATASET_PATH}/metadata/users_metadata.parquet\")\n",
+ "items_metadata = pl.read_parquet(f\"{DATASET_PATH}/metadata/items_metadata.parquet\")\n",
+ "\n",
+ "users_metadata = users_metadata.join(all_data_users, on='user_id')\n",
+ "items_metadata = items_metadata.join(all_data_items, on='item_id')\n",
+ "items_metadata = items_metadata.join(pl.DataFrame({'item_id': item_ids, \n",
+ " 'embedding': item_embeddings}), on='item_id')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Сжатие айтем айди и ремапинг"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Total users: 79074, Total items: 62758\n"
+ ]
+ }
+ ],
+ "source": [
+ "all_data_items = all_data_interactions.select('item_id').unique()\n",
+ "all_data_users = all_data_interactions.select('user_id').unique()\n",
+ "\n",
+ "unique_items_sorted = all_data_items.sort('item_id').with_row_index('new_item_id')\n",
+ "global_item_mapping = dict(zip(unique_items_sorted['item_id'], unique_items_sorted['new_item_id']))\n",
+ "\n",
+ "print(f\"Total users: {all_data_users.shape[0]}, Total items: {len(global_item_mapping)}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def remap_interactions(df, mapping):\n",
+ " return df.with_columns(\n",
+ " pl.col('item_id')\n",
+ " .map_elements(lambda x: mapping.get(x, None), return_dtype=pl.UInt32)\n",
+ " )\n",
+ "\n",
+ "base_interactions_remapped = remap_interactions(base_interactions, global_item_mapping)\n",
+ "gap_interactions_remapped = remap_interactions(gap_interactions, global_item_mapping)\n",
+ "test_interactions_remapped = remap_interactions(test_interactions, global_item_mapping)\n",
+ "val_interactions_remapped = remap_interactions(val_interactions, global_item_mapping)\n",
+ "all_data_interactions_remapped = remap_interactions(all_data_interactions, global_item_mapping)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "del base_interactions, gap_interactions, test_interactions, val_interactions, all_data_interactions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "base_with_gap_interactions_remapped = remap_interactions(base_with_gap_interactions, global_item_mapping)\n",
+ "del base_with_gap_interactions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "items_metadata_remapped = remap_interactions(items_metadata, global_item_mapping)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Группировка по юзер айди"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "interactions count: (1244323, 13)\n",
+ "users count: (74862, 3)\n",
+ "interactions count: (176791, 13)\n",
+ "users count: (44444, 3)\n",
+ "interactions count: (170111, 13)\n",
+ "users count: (43370, 3)\n",
+ "interactions count: (164151, 13)\n",
+ "users count: (43233, 3)\n",
+ "interactions count: (1755376, 13)\n",
+ "users count: (79074, 3)\n"
+ ]
+ }
+ ],
+ "source": [
+ "def get_grouped_interactions(data_interactions):\n",
+ " print(f\"interactions count: {data_interactions.shape}\")\n",
+ " data_res = (\n",
+ " data_interactions\n",
+ " .select(['original_order', 'user_id', 'item_id'])\n",
+ " .group_by('user_id')\n",
+ " .agg(\n",
+ " pl.col('item_id')\n",
+ " .sort_by(pl.col('original_order'))\n",
+ " .alias('item_ids'),\n",
+ " pl.col('original_order').alias('timestamps')\n",
+ " )\n",
+ " .rename({'user_id': 'uid'})\n",
+ " )\n",
+ " print(f\"users count: {data_res.shape}\")\n",
+ " return data_res\n",
+ "base_interactions_grouped = get_grouped_interactions(base_interactions_remapped)\n",
+ "gap_interactions_grouped = get_grouped_interactions(gap_interactions_remapped)\n",
+ "test_interactions_grouped = get_grouped_interactions(test_interactions_remapped)\n",
+ "val_interactions_grouped = get_grouped_interactions(val_interactions_remapped)\n",
+ "all_data_interactions_grouped = get_grouped_interactions(all_data_interactions_remapped)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "
shape: (1, 3)| uid | item_ids | timestamps |
|---|
| u32 | list[u32] | list[u32] |
| 2655558 | [16621, 42990, … 51285] | [46109, 59132, … 1209536] |
"
+ ],
+ "text/plain": [
+ "shape: (1, 3)\n",
+ "┌─────────┬─────────────────────────┬───────────────────────────┐\n",
+ "│ uid ┆ item_ids ┆ timestamps │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ u32 ┆ list[u32] ┆ list[u32] │\n",
+ "╞═════════╪═════════════════════════╪═══════════════════════════╡\n",
+ "│ 2655558 ┆ [16621, 42990, … 51285] ┆ [46109, 59132, … 1209536] │\n",
+ "└─────────┴─────────────────────────┴───────────────────────────┘"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "base_interactions_grouped.head(1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "del base_interactions_remapped, gap_interactions_remapped, test_interactions_remapped, val_interactions_remapped"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "interactions count: (1421114, 13)\n",
+ "users count: (76483, 3)\n"
+ ]
+ }
+ ],
+ "source": [
+ "base_with_gap_interactions_grouped = get_grouped_interactions(base_with_gap_interactions_remapped)\n",
+ "del base_with_gap_interactions_remapped"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Сохранение"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Сохранён маппинг: /home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/global_item_mapping.json\n"
+ ]
+ }
+ ],
+ "source": [
+ "import json\n",
+ "OUTPUT_DIR = \"/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows\"\n",
+ "\n",
+ "mapping_output_path = f\"{OUTPUT_DIR}/global_item_mapping.json\"\n",
+ "\n",
+ "with open(mapping_output_path, 'w') as f:\n",
+ " json.dump({str(k): v for k, v in global_item_mapping.items()}, f, indent=2)\n",
+ "\n",
+ "print(f\"Сохранён маппинг: {mapping_output_path}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Сохранен файл: items_metadata_remapped\n",
+ "Сохранен файл: items_metadata_remapped_old\n",
+ "Сохранен файл: base_interactions_grouped\n",
+ "Сохранен файл: gap_interactions_grouped\n",
+ "Сохранен файл: test_interactions_grouped\n",
+ "Сохранен файл: val_interactions_grouped\n",
+ "Сохранен файл: base_with_gap_interactions_grouped\n",
+ "Сохранен файл: all_data_interactions_grouped\n",
+ "Сохранен файл: all_data_interactions_remapped\n"
+ ]
+ }
+ ],
+ "source": [
+ "def write_parquet(output_dir, data, file_name):\n",
+ " output_parquet_path = f\"{output_dir}/{file_name}.parquet\"\n",
+ " data.write_parquet(output_parquet_path)\n",
+ " print(f\"Сохранен файл: {file_name}\")\n",
+ "\n",
+ "write_parquet(OUTPUT_DIR, items_metadata_remapped, \"items_metadata_remapped\")\n",
+ "write_parquet(OUTPUT_DIR, items_metadata, \"items_metadata_remapped_old\")\n",
+ "\n",
+ "write_parquet(OUTPUT_DIR, base_interactions_grouped, \"base_interactions_grouped\")\n",
+ "write_parquet(OUTPUT_DIR, gap_interactions_grouped, \"gap_interactions_grouped\")\n",
+ "write_parquet(OUTPUT_DIR, test_interactions_grouped, \"test_interactions_grouped\")\n",
+ "write_parquet(OUTPUT_DIR, val_interactions_grouped, \"val_interactions_grouped\")\n",
+ "write_parquet(OUTPUT_DIR, base_with_gap_interactions_grouped, \"base_with_gap_interactions_grouped\")\n",
+ "\n",
+ "write_parquet(OUTPUT_DIR, all_data_interactions_grouped, \"all_data_interactions_grouped\")\n",
+ "\n",
+ "write_parquet(OUTPUT_DIR, all_data_interactions_remapped, \"all_data_interactions_remapped\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(64, 64)"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(list(items_metadata_remapped.head(1)['embedding'].item())), len(list(items_metadata.head(1)['embedding'].item()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(62758, 5)"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "items_metadata_remapped.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (1, 5)| item_id | author_id | duration | train_interactions_rank | embedding |
|---|
| u32 | u32 | u8 | u32 | array[f32, 64] |
| 0 | 1249424 | 9 | 771612 | [-0.503418, 0.201538, … 0.007988] |
"
+ ],
+ "text/plain": [
+ "shape: (1, 5)\n",
+ "┌─────────┬───────────┬──────────┬─────────────────────────┬─────────────────────────────────┐\n",
+ "│ item_id ┆ author_id ┆ duration ┆ train_interactions_rank ┆ embedding │\n",
+ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
+ "│ u32 ┆ u32 ┆ u8 ┆ u32 ┆ array[f32, 64] │\n",
+ "╞═════════╪═══════════╪══════════╪═════════════════════════╪═════════════════════════════════╡\n",
+ "│ 0 ┆ 1249424 ┆ 9 ┆ 771612 ┆ [-0.503418, 0.201538, … 0.0079… │\n",
+ "└─────────┴───────────┴──────────┴─────────────────────────┴─────────────────────────────────┘"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "items_metadata_remapped.head(1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (5, 3)| uid | item_ids | timestamps |
|---|
| u32 | list[u32] | list[u32] |
| 4465123 | [28298, 3829, … 28995] | [257260, 272293, … 1390041] |
| 3043171 | [8638, 23487, … 15086] | [6628, 11364, … 1370935] |
| 2757146 | [56345, 56828, … 37056] | [194522, 217739, … 1390752] |
| 1148408 | [40326, 42152] | [427153, 1367211] |
| 2537065 | [27766, 39966, … 19887] | [9428, 35459, … 1214991] |
"
+ ],
+ "text/plain": [
+ "shape: (5, 3)\n",
+ "┌─────────┬─────────────────────────┬─────────────────────────────┐\n",
+ "│ uid ┆ item_ids ┆ timestamps │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ u32 ┆ list[u32] ┆ list[u32] │\n",
+ "╞═════════╪═════════════════════════╪═════════════════════════════╡\n",
+ "│ 4465123 ┆ [28298, 3829, … 28995] ┆ [257260, 272293, … 1390041] │\n",
+ "│ 3043171 ┆ [8638, 23487, … 15086] ┆ [6628, 11364, … 1370935] │\n",
+ "│ 2757146 ┆ [56345, 56828, … 37056] ┆ [194522, 217739, … 1390752] │\n",
+ "│ 1148408 ┆ [40326, 42152] ┆ [427153, 1367211] │\n",
+ "│ 2537065 ┆ [27766, 39966, … 19887] ┆ [9428, 35459, … 1214991] │\n",
+ "└─────────┴─────────────────────────┴─────────────────────────────┘"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "base_with_gap_interactions_grouped.head()"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/scripts/plum-lsvd/callbacks.py b/scripts/plum-lsvd/callbacks.py
new file mode 100644
index 0000000..43ec460
--- /dev/null
+++ b/scripts/plum-lsvd/callbacks.py
@@ -0,0 +1,64 @@
+import torch
+
+import irec.callbacks as cb
+from irec.runners import TrainingRunner, TrainingRunnerContext
+
class InitCodebooks(cb.TrainingCallback):
    """Training callback that seeds every residual codebook from encoded data.

    For codebook ``i`` a batch is taken from the dataloader, a random subset of
    its rows is encoded, the nearest vectors of the already initialised
    codebooks ``0..i-1`` are subtracted, and the resulting residuals are written
    into codebook ``i``.
    """

    def __init__(self, dataloader):
        """Keep the dataloader; its batches must expose an ``'embedding'`` entry."""
        super().__init__()
        self._dataloader = dataloader

    @torch.no_grad()
    def before_run(self, runner: TrainingRunner):
        """Initialise ``runner.model.codebooks`` level by level before training starts."""
        model = runner.model
        for level in range(len(model.codebooks)):
            codebook = model.codebooks[level]

            # A fresh iterator is created per level, so each level uses the
            # first batch that iterator yields.
            embeddings = next(iter(self._dataloader))['embedding']
            chosen = torch.randperm(embeddings.shape[0], device=embeddings.device)[:len(codebook)]
            remainder = model.encoder(embeddings[chosen])

            # Remove what the earlier (already initialised) levels explain.
            for previous_level in range(level):
                previous = model.codebooks[previous_level]
                nearest = model.get_codebook_indices(remainder, previous)
                remainder = remainder - previous[nearest]

            # Write in place instead of rebinding ``.data``: when the batch has
            # fewer rows than the codebook, rebinding silently shrank the
            # codebook. Now only the first ``num_init`` rows are overwritten and
            # the codebook keeps its configured size.
            num_init = remainder.shape[0]
            codebook[:num_init] = remainder.detach()
+
+
class FixDeadCentroids(cb.TrainingCallback):
    """Training callback that re-seeds codebook vectors no sample is assigned to.

    After each step it counts, per codebook, how many samples of the dataloader
    select each centroid; centroids with a zero count are overwritten with
    residuals computed from a batch of the same dataloader.
    """

    def __init__(self, dataloader):
        """Keep the dataloader; its batches must expose an ``'embedding'`` entry."""
        super().__init__()
        self._dataloader = dataloader

    def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext):
        """Fix dead centroids and record the per-codebook dead count in ``context.metrics``."""
        for i, num_fixed in enumerate(self.fix_dead_codebooks(runner)):
            context.metrics[f'num_dead/{i}'] = num_fixed

    @staticmethod
    def _residual_at_level(model, embeddings, level):
        """Encode ``embeddings`` and subtract the nearest vectors of codebooks ``0..level-1``."""
        remainder = model.encoder(embeddings)
        for previous_level in range(level):
            previous = model.codebooks[previous_level]
            nearest = model.get_codebook_indices(remainder, previous)
            remainder = remainder - previous[nearest]
        return remainder

    @torch.no_grad()
    def fix_dead_codebooks(self, runner: TrainingRunner):
        """Replace unused centroids in every codebook.

        Returns:
            list[int]: for each codebook, the number of centroids that received
            no assignment during the pass over the dataloader.
        """
        model = runner.model
        num_fixed = []
        for codebook_idx, codebook in enumerate(model.codebooks):
            centroid_counts = torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device)
            # Source of replacement vectors; taken before the counting pass.
            random_batch = next(iter(self._dataloader))['embedding']

            for batch in self._dataloader:
                remainder = self._residual_at_level(model, batch['embedding'], codebook_idx)
                indices = model.get_codebook_indices(remainder, codebook)
                centroid_counts.scatter_add_(0, indices, torch.ones_like(indices))

            dead_mask = (centroid_counts == 0)
            num_dead = int(dead_mask.sum().item())
            num_fixed.append(num_dead)
            if num_dead == 0:
                continue

            remainder = self._residual_at_level(model, random_batch, codebook_idx)
            shuffled = torch.randperm(remainder.shape[0], device=codebook.device)
            replacements = remainder[shuffled][:num_dead]

            # The batch may hold fewer rows than there are dead centroids; a
            # boolean-mask assignment would then fail with a shape mismatch.
            # Re-seed as many dead centroids as there are replacement rows.
            dead_indices = dead_mask.nonzero(as_tuple=True)[0][:replacements.shape[0]]
            codebook[dead_indices] = replacements.detach()

        return num_fixed
diff --git a/scripts/plum-lsvd/cooc_data.py b/scripts/plum-lsvd/cooc_data.py
new file mode 100644
index 0000000..7cea906
--- /dev/null
+++ b/scripts/plum-lsvd/cooc_data.py
@@ -0,0 +1,117 @@
+import json
+from collections import defaultdict, Counter
+from data import InteractionsDatasetParquet
+from collections import defaultdict, Counter
+
+
class CoocMappingDataset:
    """Per-user training sessions together with window-based item co-occurrence counts."""

    def __init__(
        self,
        train_sampler,
        num_items,
        cooccur_counter_mapping=None
    ):
        """Store the training sessions, the item count and the co-occurrence mapping."""
        self._train_sampler = train_sampler
        self._num_items = num_items
        self._cooccur_counter_mapping = cooccur_counter_mapping

    @classmethod
    def create(cls, inter_json_path, window_size):
        """Build the dataset from a JSON file mapping user id -> list of item ids.

        The last two items of every user are dropped from the training session.
        A message is printed for every user with fewer than 5 items.

        Args:
            inter_json_path: path to the JSON file with user histories.
            window_size: number of neighbours on each side counted as co-occurring.
        """
        max_item_id = 0
        train_dataset = []

        with open(inter_json_path, 'r') as f:
            user_interactions = json.load(f)

        for user_id_str, item_ids in user_interactions.items():
            user_id = int(user_id_str)
            if item_ids:
                max_item_id = max(max_item_id, max(item_ids))
            # The message describes users that violate the core-5 assumption,
            # so it must fire for short histories (the check used to be `>= 5`).
            if len(item_ids) < 5:
                print(f'Core-5 dataset is used, user {user_id} has only {len(item_ids)} items')
            train_dataset.append({
                'user_ids': [user_id],
                'item_ids': item_ids[:-2],
            })

        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size)
        print(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items but max_item_id is {max_item_id}')

        return cls(
            train_sampler=train_dataset,
            num_items=max_item_id + 1,
            cooccur_counter_mapping=cooccur_counter_mapping
        )

    @classmethod
    def create_from_split_part(
        cls,
        train_inter_parquet_path,
        window_size,
    ):
        """Build the dataset from a parquet split read through ``InteractionsDatasetParquet``.

        Unlike ``create``, the full item sequence of every session is used.

        Args:
            train_inter_parquet_path: path to the parquet file with grouped interactions.
            window_size: number of neighbours on each side counted as co-occurring.
        """
        max_item_id = 0
        train_dataset = []

        train_interactions = InteractionsDatasetParquet(train_inter_parquet_path)

        actions_num = 0
        for session in train_interactions:
            user_id, item_ids = int(session['user_id']), session['item_ids']
            # Length check instead of `.any()`: `.any()` tests for non-zero
            # values (item id 0 is valid) and only exists on numpy arrays.
            if len(item_ids) > 0:
                # int() keeps max_item_id / num_items plain Python integers.
                max_item_id = max(max_item_id, int(max(item_ids)))
            actions_num += len(item_ids)
            train_dataset.append({
                'user_ids': [user_id],
                'item_ids': item_ids,
            })

        print(f'Train: {len(train_dataset)} users')
        print(f'Max item ID: {max_item_id}')
        print(f"Actions num: {actions_num}")

        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(
            train_dataset,
            window_size=window_size
        )

        print(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items')

        return cls(
            train_sampler=train_dataset,
            num_items=max_item_id + 1,
            cooccur_counter_mapping=cooccur_counter_mapping
        )

    @staticmethod
    def build_cooccur_counter_mapping(train_dataset, window_size):
        """Count, for every item, the items seen within ``window_size`` positions of it.

        Args:
            train_dataset: iterable of dicts with an ``'item_ids'`` sequence.
            window_size: number of positions to look at on each side.

        Returns:
            defaultdict[item, Counter]: ``mapping[a][b]`` is how many times ``b``
            appeared within the window around an occurrence of ``a``.
        """
        cooccur_counts = defaultdict(Counter)
        for session in train_dataset:
            items = session['item_ids']
            for i in range(len(items)):
                item_i = items[i]
                for j in range(max(0, i - window_size), min(len(items), i + window_size + 1)):
                    if i != j:
                        cooccur_counts[item_i][items[j]] += 1
        max_hist_len = max(len(counter) for counter in cooccur_counts.values()) if cooccur_counts else 0
        print(f"Max cooccurrence history length is {max_hist_len}")
        return cooccur_counts

    @property
    def cooccur_counter_mapping(self):
        """Mapping item -> Counter of co-occurring items (may be None)."""
        return self._cooccur_counter_mapping
\ No newline at end of file
diff --git a/scripts/plum-lsvd/data.py b/scripts/plum-lsvd/data.py
new file mode 100644
index 0000000..5a780fb
--- /dev/null
+++ b/scripts/plum-lsvd/data.py
@@ -0,0 +1,87 @@
+import numpy as np
+import pickle
+
+from irec.data.base import BaseDataset
+from irec.data.transforms import Transform
+
+
+import polars as pl
+
class InteractionsDatasetParquet(BaseDataset):
    """Dataset over a parquet file with one row per user (``uid``, ``item_ids``)."""

    def __init__(self, data_path, max_items=None):
        """Read the parquet file and optionally keep only the last ``max_items`` items per user.

        Args:
            data_path: path to a parquet file with ``uid`` and ``item_ids`` columns.
            max_items: if not None, every ``item_ids`` list is cut to its last
                ``max_items`` entries.
        """
        self.df = pl.read_parquet(data_path)
        assert 'uid' in self.df.columns, "Missing 'uid' column"
        assert 'item_ids' in self.df.columns, "Missing 'item_ids' column"
        print(f"Dataset loaded: {len(self.df)} users")

        if max_items is not None:
            # `list.tail` instead of `list.slice(-max_items)`: with
            # max_items == 0 the slice offset becomes 0 and the whole list is kept.
            self.df = self.df.with_columns(
                pl.col("item_ids").list.tail(max_items).alias("item_ids")
            )

    def __getitem__(self, idx):
        """Return ``{'user_id', 'item_ids'}`` for row ``idx``; ``item_ids`` is a uint32 array."""
        row = self.df.row(idx, named=True)
        return {
            'user_id': row['uid'],
            'item_ids': np.array(row['item_ids'], dtype=np.uint32),
        }

    def __len__(self):
        """Number of rows (users) in the file."""
        return len(self.df)

    def __iter__(self):
        """Yield every sample in row order."""
        for idx in range(len(self)):
            yield self[idx]
+
+
class EmbeddingDatasetParquet(BaseDataset):
    """Dataset of item embeddings read from a parquet file (``item_id``, ``embedding``)."""

    def __init__(self, data_path):
        """Load item ids as int64 and embeddings as a float32 array, one row per item."""
        self.df = pl.read_parquet(data_path)
        self.item_ids = np.array(self.df['item_id'], dtype=np.int64)
        self.embeddings = np.array(self.df['embedding'].to_list(), dtype=np.float32)
        # shape[1:] gives the per-item shape without indexing row 0, so an
        # empty file no longer raises IndexError here.
        print(f"embedding dim: {self.embeddings.shape[1:]}")

    def __getitem__(self, idx):
        """Return the item id, its embedding and the embedding length for position ``idx``."""
        index = self.item_ids[idx]
        tensor_emb = self.embeddings[idx]
        return {
            'item_id': index,
            'embedding': tensor_emb,
            'embedding_dim': len(tensor_emb)
        }

    def __len__(self):
        """Number of embeddings."""
        return len(self.embeddings)
+
+
class EmbeddingDataset(BaseDataset):
    """Dataset of item embeddings read from a pickle file with ``item_id`` and ``embedding`` entries."""

    def __init__(self, data_path):
        """Unpickle ``data_path`` and expose ids (int64) and embeddings (float32) as arrays."""
        self.data_path = data_path
        # NOTE(review): pickle.load can execute arbitrary code; only point this
        # at files from a trusted source.
        with open(data_path, 'rb') as handle:
            self.data = pickle.load(handle)

        raw_ids = self.data['item_id']
        raw_embeddings = self.data['embedding']
        self.item_ids = np.array(raw_ids, dtype=np.int64)
        self.embeddings = np.array(raw_embeddings, dtype=np.float32)

    def __getitem__(self, idx):
        """Return the item id, its embedding and the embedding length for position ``idx``."""
        item_id = self.item_ids[idx]
        embedding = self.embeddings[idx]
        sample = {'item_id': item_id, 'embedding': embedding}
        sample['embedding_dim'] = len(embedding)
        return sample

    def __len__(self):
        """Number of embeddings."""
        return self.embeddings.shape[0]
+
+
class ProcessEmbeddings(Transform):
    """Batch transform that reshapes the selected entries to ``(-1, embedding_dim)``."""

    def __init__(self, embedding_dim, keys):
        """Remember the target last dimension and the batch keys to reshape."""
        self.keys = keys
        self.embedding_dim = embedding_dim

    def __call__(self, batch):
        """Reshape ``batch[key]`` in place for every configured key and return the batch."""
        target_shape = (-1, self.embedding_dim)
        for name in self.keys:
            reshaped = batch[name].reshape(*target_shape)
            batch[name] = reshaped
        return batch
\ No newline at end of file
diff --git a/scripts/plum-lsvd/infer_default.py b/scripts/plum-lsvd/infer_default.py
new file mode 100644
index 0000000..b15fb6d
--- /dev/null
+++ b/scripts/plum-lsvd/infer_default.py
@@ -0,0 +1,146 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+
+from data import EmbeddingDataset, ProcessEmbeddings
+from models import PlumRQVAE
+
# Paths
IREC_PATH = '/home/jovyan/IRec/'
EMBEDDINGS_PATH = '/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl'
MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-1_plum_rqvae_beauty_ws_2_best_0.0051.pth'
RESULTS_PATH = os.path.join(IREC_PATH, 'results')

# Experiment identity (WINDOW_SIZE is only used to build the name here)
WINDOW_SIZE = 2
EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}'

# Everything else: reproducibility and runtime
SEED_VALUE = 42
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 1024

# Model hyper-parameters passed to PlumRQVAE
INPUT_DIM = 4096
HIDDEN_DIM = 32
CODEBOOK_SIZE = 256
NUM_CODEBOOKS = 3
BETA = 0.25
+
+
+
def main():
    """Run PlumRQVAE inference over all item embeddings and write semantic ids.

    Steps:
      1. Load the embeddings and the checkpoint, run the inference runner and
         save the per-item cluster indices to ``<EXPERIMENT_NAME>_clusters.json``.
      2. Read that file back, append one extra "collision solver" index to every
         item so that items sharing the same cluster tuple stay distinguishable,
         shift position ``i`` by ``CODEBOOK_SIZE * i``, and write the result to
         ``<EXPERIMENT_NAME>_clusters_colisionless.json``.
    """
    # Function-scoped imports, as in the original script, hoisted to the top.
    import json
    from collections import defaultdict
    import numpy as np

    fix_random_seed(SEED_VALUE)

    dataset = EmbeddingDataset(
        data_path=EMBEDDINGS_PATH
    )

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))

    model = PlumRQVAE(
        input_dim=INPUT_DIM,
        num_codebooks=NUM_CODEBOOKS,
        codebook_size=CODEBOOK_SIZE,
        embedding_dim=HIDDEN_DIM,
        beta=BETA,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=1.0
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    # Both output files live in RESULTS_PATH; make sure it exists before writing.
    os.makedirs(RESULTS_PATH, exist_ok=True)
    clusters_path = os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json')

    callbacks = [
        cb.LoadModel(MODEL_PATH),

        cb.BatchMetrics(metrics=lambda model_outputs, _: {
            'loss': model_outputs['loss'],
            'recon_loss': model_outputs['recon_loss'],
            'rqvae_loss': model_outputs['rqvae_loss'],
            'con_loss': model_outputs['con_loss']
        }, name='valid'),

        cb.MetricAccumulator(
            accumulators={
                'valid/loss': cb.MeanAccumulator(),
                'valid/recon_loss': cb.MeanAccumulator(),
                'valid/rqvae_loss': cb.MeanAccumulator(),
                'valid/con_loss': cb.MeanAccumulator(),
            },
        ),

        cb.Logger().every_num_steps(len(dataloader)),

        cb.InferenceSaver(
            metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
            save_path=clusters_path,
            format='json'
        )
    ]

    logger.debug('Everything is ready for inference process!')

    runner = InferenceRunner(
        model=model,
        dataset=dataloader,
        callbacks=callbacks,
    )
    runner.run()

    with open(clusters_path, 'r') as f:
        mappings = json.load(f)

    # inter: item id -> list of codes; sem_2_ids: cluster tuple -> items sharing it.
    inter = {}
    sem_2_ids = defaultdict(list)
    for mapping in mappings:
        item_id = mapping['item_id']
        clusters = mapping['clusters']
        inter[int(item_id)] = clusters
        sem_2_ids[tuple(clusters)].append(int(item_id))

    for semantics, items in sem_2_ids.items():
        # Each colliding item needs its own extra code out of CODEBOOK_SIZE values.
        assert len(items) <= CODEBOOK_SIZE, str(len(items))
        collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
        for item_id, collision_solver in zip(items, collision_solvers):
            inter[item_id].append(collision_solver)
            # Shift position i into its own id range [CODEBOOK_SIZE * i, CODEBOOK_SIZE * (i + 1)).
            for i in range(len(inter[item_id])):
                inter[item_id][i] += CODEBOOK_SIZE * i

    with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
        json.dump(inter, f, indent=2)


if __name__ == '__main__':
    main()
diff --git a/scripts/plum-lsvd/infer_plum_4.1.py b/scripts/plum-lsvd/infer_plum_4.1.py
new file mode 100644
index 0000000..bb70a9d
--- /dev/null
+++ b/scripts/plum-lsvd/infer_plum_4.1.py
@@ -0,0 +1,146 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from models import PlumRQVAE
+
# Experiment with the full history
IREC_PATH = '../../'
MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-1_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e35_best_0.0096.pth'
EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/items_metadata_remapped.parquet"
RESULTS_PATH = os.path.join(IREC_PATH, 'results')

# Values that identify the run (all three appear in the experiment name)
WINDOW_SIZE = 2
CODEBOOK_SIZE = 512
K = 2000
EXPERIMENT_NAME = f'4-1_vk_lsvd_ods_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_8w_e_35'

# Everything else: reproducibility and runtime
SEED_VALUE = 42
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 1024

# Model hyper-parameters passed to PlumRQVAE
INPUT_DIM = 64
HIDDEN_DIM = 32
NUM_CODEBOOKS = 3
BETA = 0.25
+
+
+
def main():
    """Run PlumRQVAE inference over the parquet item embeddings and write semantic ids.

    Steps:
      1. Load the embeddings and the checkpoint, run the inference runner and
         save the per-item cluster indices to ``<EXPERIMENT_NAME>_clusters.json``.
      2. Read that file back, append one extra "collision solver" index to every
         item so that items sharing the same cluster tuple stay distinguishable,
         shift position ``i`` by ``CODEBOOK_SIZE * i``, and write the result to
         ``<EXPERIMENT_NAME>_clusters_colisionless.json``.
    """
    # Function-scoped imports, as in the original script, hoisted to the top.
    import json
    from collections import defaultdict
    import numpy as np

    fix_random_seed(SEED_VALUE)

    dataset = EmbeddingDatasetParquet(
        data_path=EMBEDDINGS_PATH
    )

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))

    model = PlumRQVAE(
        input_dim=INPUT_DIM,
        num_codebooks=NUM_CODEBOOKS,
        codebook_size=CODEBOOK_SIZE,
        embedding_dim=HIDDEN_DIM,
        beta=BETA,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=1.0
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    # Both output files live in RESULTS_PATH; make sure it exists before writing.
    os.makedirs(RESULTS_PATH, exist_ok=True)
    clusters_path = os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json')

    callbacks = [
        cb.LoadModel(MODEL_PATH),

        cb.BatchMetrics(metrics=lambda model_outputs, _: {
            'loss': model_outputs['loss'],
            'recon_loss': model_outputs['recon_loss'],
            'rqvae_loss': model_outputs['rqvae_loss'],
            'con_loss': model_outputs['con_loss']
        }, name='valid'),

        cb.MetricAccumulator(
            accumulators={
                'valid/loss': cb.MeanAccumulator(),
                'valid/recon_loss': cb.MeanAccumulator(),
                'valid/rqvae_loss': cb.MeanAccumulator(),
                'valid/con_loss': cb.MeanAccumulator(),
            },
        ),

        cb.Logger().every_num_steps(len(dataloader)),

        cb.InferenceSaver(
            metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
            save_path=clusters_path,
            format='json'
        )
    ]

    logger.debug('Everything is ready for inference process!')

    runner = InferenceRunner(
        model=model,
        dataset=dataloader,
        callbacks=callbacks,
    )
    runner.run()

    with open(clusters_path, 'r') as f:
        mappings = json.load(f)

    # inter: item id -> list of codes; sem_2_ids: cluster tuple -> items sharing it.
    inter = {}
    sem_2_ids = defaultdict(list)
    for mapping in mappings:
        item_id = mapping['item_id']
        clusters = mapping['clusters']
        inter[int(item_id)] = clusters
        sem_2_ids[tuple(clusters)].append(int(item_id))

    for semantics, items in sem_2_ids.items():
        # Each colliding item needs its own extra code out of CODEBOOK_SIZE values.
        assert len(items) <= CODEBOOK_SIZE, str(len(items))
        collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
        for item_id, collision_solver in zip(items, collision_solvers):
            inter[item_id].append(collision_solver)
            # Shift position i into its own id range [CODEBOOK_SIZE * i, CODEBOOK_SIZE * (i + 1)).
            for i in range(len(inter[item_id])):
                inter[item_id][i] += CODEBOOK_SIZE * i

    with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
        json.dump(inter, f, indent=2)


if __name__ == '__main__':
    main()
diff --git a/scripts/plum-lsvd/infer_plum_4.2.py b/scripts/plum-lsvd/infer_plum_4.2.py
new file mode 100644
index 0000000..977c0b5
--- /dev/null
+++ b/scripts/plum-lsvd/infer_plum_4.2.py
@@ -0,0 +1,146 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from models import PlumRQVAE
+
# Experiment with truncated history.
# IREC_PATH is relative, so RESULTS_PATH below resolves against the current working directory.
IREC_PATH = '../../'
MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-2_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e35_best_0.0096.pth'
# NOTE(review): this path says "8-days-base-ows", while train_plum_4.2.py and infer_rqvae.py
# read ".../8-weeks-base-ows/items_metadata_remapped.parquet" — confirm this is not a typo.
EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/items_metadata_remapped.parquet"

RESULTS_PATH = os.path.join(IREC_PATH, 'results')

# WINDOW_SIZE and K are only interpolated into EXPERIMENT_NAME in this script.
WINDOW_SIZE = 2
CODEBOOK_SIZE = 512
K = 2000
EXPERIMENT_NAME = f'4-2_vk_lsvd_ods_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_8w_e_35'
# Everything else.

SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


BATCH_SIZE = 1024

# Model hyper-parameters passed to PlumRQVAE in main(); presumably they must match the
# checkpoint loaded from MODEL_PATH — verify against the training script.
INPUT_DIM = 64
HIDDEN_DIM = 32
NUM_CODEBOOKS = 3
BETA = 0.25
+
+
def main():
    """Run PlumRQVAE inference over all item embeddings and export collision-free codes.

    The function
      1. seeds the RNGs via ``fix_random_seed`` and builds a non-shuffled dataloader
         over ``EMBEDDINGS_PATH``;
      2. runs ``InferenceRunner`` with the weights from ``MODEL_PATH``; the
         ``InferenceSaver`` callback is configured to store ``item_id`` / ``clusters``
         records in ``<EXPERIMENT_NAME>_clusters.json``;
      3. re-reads that file, appends one extra "collision solver" token to every item so
         that items sharing the same cluster tuple stay distinguishable, shifts token
         ``i`` by ``CODEBOOK_SIZE * i`` and writes the mapping to
         ``<EXPERIMENT_NAME>_clusters_colisionless.json``.
    """
    import json
    from collections import defaultdict
    import numpy as np

    fix_random_seed(SEED_VALUE)

    dataset = EmbeddingDatasetParquet(
        data_path=EMBEDDINGS_PATH
    )

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))

    model = PlumRQVAE(
        input_dim=INPUT_DIM,
        num_codebooks=NUM_CODEBOOKS,
        codebook_size=CODEBOOK_SIZE,
        embedding_dim=HIDDEN_DIM,
        beta=BETA,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=1.0
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    # Make sure the output directory exists before anything tries to write into it.
    os.makedirs(RESULTS_PATH, exist_ok=True)
    clusters_path = os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json')
    collisionless_path = os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json')

    callbacks = [
        cb.LoadModel(MODEL_PATH),

        cb.BatchMetrics(metrics=lambda model_outputs, _: {
            'loss': model_outputs['loss'],
            'recon_loss': model_outputs['recon_loss'],
            'rqvae_loss': model_outputs['rqvae_loss'],
            'con_loss': model_outputs['con_loss']
        }, name='valid'),

        cb.MetricAccumulator(
            accumulators={
                'valid/loss': cb.MeanAccumulator(),
                'valid/recon_loss': cb.MeanAccumulator(),
                'valid/rqvae_loss': cb.MeanAccumulator(),
                'valid/con_loss': cb.MeanAccumulator(),
            },
        ),

        cb.Logger().every_num_steps(len(dataloader)),

        cb.InferenceSaver(
            metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
            save_path=clusters_path,
            format='json'
        )
    ]

    logger.debug('Everything is ready for inference process!')

    runner = InferenceRunner(
        model=model,
        dataset=dataloader,
        callbacks=callbacks,
    )
    runner.run()

    with open(clusters_path, 'r') as f:
        mappings = json.load(f)

    # inter: item_id -> list of codes; sem_2_ids: cluster tuple -> items that share it.
    inter = {}
    sem_2_ids = defaultdict(list)
    for mapping in mappings:
        item_id = int(mapping['item_id'])
        clusters = mapping['clusters']
        inter[item_id] = clusters
        sem_2_ids[tuple(clusters)].append(item_id)

    for semantics, items in sem_2_ids.items():
        # Only CODEBOOK_SIZE distinct collision tokens exist per cluster tuple.
        assert len(items) <= CODEBOOK_SIZE, (
            f'{len(items)} items share clusters {semantics}, '
            f'but only {CODEBOOK_SIZE} collision tokens are available'
        )
        # Distinct random tokens for the items of this group (global NumPy RNG).
        collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
        for item_id, collision_solver in zip(items, collision_solvers):
            codes = inter[item_id]
            codes.append(collision_solver)
            # Shift position i into its own id range [CODEBOOK_SIZE * i, CODEBOOK_SIZE * (i + 1)).
            for i in range(len(codes)):
                codes[i] += CODEBOOK_SIZE * i

    with open(collisionless_path, 'w') as f:
        json.dump(inter, f, indent=2)


if __name__ == '__main__':
    main()
diff --git a/scripts/plum-lsvd/infer_rqvae.py b/scripts/plum-lsvd/infer_rqvae.py
new file mode 100644
index 0000000..53a587c
--- /dev/null
+++ b/scripts/plum-lsvd/infer_rqvae.py
@@ -0,0 +1,161 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from collections import Counter
+from models import PlumRQVAE
+
# Paths.
IREC_PATH = '/home/jovyan/IRec/'
EMBEDDINGS_PATH = '/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/items_metadata_remapped.parquet'
MODEL_PATH = '/home/jovyan/IRec/checkpoints/rqvae_vk_lsvd_cz_512_8-weeks_best_0.009.pth'
RESULTS_PATH = os.path.join(IREC_PATH, 'results')

# NOTE(review): WINDOW_SIZE is not referenced anywhere else in this script.
WINDOW_SIZE = 2

# Used as the file-name prefix of the JSON files written by main().
EXPERIMENT_NAME = f'rqvae_vk_lsvd_cz_512_8-weeks'

# Everything else.

SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

BATCH_SIZE = 1024

# Model hyper-parameters passed to PlumRQVAE in main(); presumably they must match the
# checkpoint loaded from MODEL_PATH — verify against the training script.
INPUT_DIM = 64
HIDDEN_DIM = 32
# CODEBOOK_SIZE is also the number of collision tokens and the per-position id offset in main().
CODEBOOK_SIZE = 512
NUM_CODEBOOKS = 3

BETA = 0.25
+
+
+
def main():
    """Run RQ-VAE inference over all item embeddings and export collision-free codes.

    The function
      1. seeds the RNGs via ``fix_random_seed`` and builds a non-shuffled dataloader
         over ``EMBEDDINGS_PATH``;
      2. runs ``InferenceRunner`` with the weights from ``MODEL_PATH``; the
         ``InferenceSaver`` callback is configured to store ``item_id`` / ``clusters``
         records in ``<EXPERIMENT_NAME>_clusters.json``;
      3. re-reads that file, appends one extra "collision solver" token to every item so
         that items sharing the same cluster tuple stay distinguishable, shifts token
         ``i`` by ``CODEBOOK_SIZE * i``, prints collision statistics and writes the
         mapping to ``<EXPERIMENT_NAME>_clusters_colisionless.json``.
    """
    import json
    from collections import defaultdict
    import numpy as np

    fix_random_seed(SEED_VALUE)

    dataset = EmbeddingDatasetParquet(
        data_path=EMBEDDINGS_PATH
    )

    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))

    model = PlumRQVAE(
        input_dim=INPUT_DIM,
        num_codebooks=NUM_CODEBOOKS,
        codebook_size=CODEBOOK_SIZE,
        embedding_dim=HIDDEN_DIM,
        beta=BETA,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=1.0
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    # Make sure the output directory exists before anything tries to write into it.
    os.makedirs(RESULTS_PATH, exist_ok=True)
    clusters_path = os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json')
    collisionless_path = os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json')

    callbacks = [
        cb.LoadModel(MODEL_PATH),

        cb.BatchMetrics(metrics=lambda model_outputs, _: {
            'loss': model_outputs['loss'],
            'recon_loss': model_outputs['recon_loss'],
            'rqvae_loss': model_outputs['rqvae_loss'],
            'con_loss': model_outputs['con_loss']
        }, name='valid'),

        cb.MetricAccumulator(
            accumulators={
                'valid/loss': cb.MeanAccumulator(),
                'valid/recon_loss': cb.MeanAccumulator(),
                'valid/rqvae_loss': cb.MeanAccumulator(),
                'valid/con_loss': cb.MeanAccumulator(),
            },
        ),

        cb.Logger().every_num_steps(len(dataloader)),

        cb.InferenceSaver(
            metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
            save_path=clusters_path,
            format='json'
        )
    ]

    logger.debug('Everything is ready for inference process!')

    runner = InferenceRunner(
        model=model,
        dataset=dataloader,
        callbacks=callbacks,
    )
    runner.run()

    with open(clusters_path, 'r') as f:
        mappings = json.load(f)

    # inter: item_id -> list of codes; sem_2_ids: cluster tuple -> items that share it.
    inter = {}
    sem_2_ids = defaultdict(list)
    collision_stats = []  # raw collision token assigned to each item (before offsetting)
    for mapping in mappings:
        item_id = int(mapping['item_id'])
        clusters = mapping['clusters']
        inter[item_id] = clusters
        sem_2_ids[tuple(clusters)].append(item_id)

    for semantics, items in sem_2_ids.items():
        # Only CODEBOOK_SIZE distinct collision tokens exist per cluster tuple.
        assert len(items) <= CODEBOOK_SIZE, (
            f'{len(items)} items share clusters {semantics}, '
            f'but only {CODEBOOK_SIZE} collision tokens are available'
        )
        # Distinct random tokens for the items of this group (global NumPy RNG).
        collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
        for item_id, collision_solver in zip(items, collision_solvers):
            codes = inter[item_id]
            codes.append(collision_solver)
            collision_stats.append(collision_solver)
            # Shift position i into its own id range [CODEBOOK_SIZE * i, CODEBOOK_SIZE * (i + 1)).
            for i in range(len(codes)):
                codes[i] += CODEBOOK_SIZE * i

    # An item "collides" only when at least one other item shares its cluster tuple.
    # collision_stats holds one token per item, so its length is NOT the collision count.
    num_colliding_items = sum(len(items) for items in sem_2_ids.values() if len(items) > 1)

    if num_colliding_items:
        max_col_tok = max(collision_stats)
        avg_col_tok = np.mean(collision_stats)
        collision_distribution = Counter(collision_stats)

        print(f"Max collision token: {max_col_tok}")
        print(f"Avg collision token: {avg_col_tok:.2f}")
        print(f"Total items with collisions: {num_colliding_items}")
        print(f"Collision solver distribution: {dict(collision_distribution)}")
    else:
        print("No collisions detected")

    with open(collisionless_path, 'w') as f:
        json.dump(inter, f, indent=2)


if __name__ == '__main__':
    main()
diff --git a/scripts/plum-lsvd/models.py b/scripts/plum-lsvd/models.py
new file mode 100644
index 0000000..d475712
--- /dev/null
+++ b/scripts/plum-lsvd/models.py
@@ -0,0 +1,131 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
class PlumRQVAE(nn.Module):
    """Residual-quantised autoencoder with an optional contrastive loss term.

    An MLP encoder maps the input embedding to a latent vector, which is quantised
    level by level against ``num_codebooks`` codebooks (each level quantises the
    residual left by the previous one). An MLP decoder reconstructs the input from the
    sum of the selected codebook vectors. When the batch carries a
    ``'cooccurrence_embedding'`` entry, an additional in-batch contrastive loss is
    computed between the quantised latents of both views.
    """

    def __init__(
        self,
        input_dim,
        num_codebooks,
        codebook_size,
        embedding_dim,
        beta=0.25,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=0.0,
    ):
        """Build encoder, decoder and codebooks.

        Args:
            input_dim: size of the input embedding (last dimension).
            num_codebooks: number of residual quantisation levels.
            codebook_size: number of vectors in every codebook.
            embedding_dim: size of the latent / codebook vectors.
            beta: weight of the ``mse(remainder, quantized.detach())`` term.
            quant_loss_weight: weight of the quantisation loss in the total loss.
            contrastive_loss_weight: weight of the contrastive loss in the total loss.
            temperature: divisor applied to the similarity logits in
                ``contrastive_loss``. The default of 0.0 is kept for backward
                compatibility, but it must be set to a positive value before the
                contrastive loss is used.
        """
        super().__init__()
        # Stored as a buffer so it moves with .to(device) and is saved in the state_dict.
        self.register_buffer('beta', torch.tensor(beta))
        self.temperature = temperature

        self.input_dim = input_dim
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.embedding_dim = embedding_dim
        self.quant_loss_weight = quant_loss_weight

        self.contrastive_loss_weight = contrastive_loss_weight

        self.encoder = self.make_encoding_tower(input_dim, embedding_dim)
        self.decoder = self.make_encoding_tower(embedding_dim, input_dim)

        self.codebooks = torch.nn.ParameterList()
        for _ in range(num_codebooks):
            # Deterministic zero init instead of uninitialised memory (which may contain
            # NaN/inf). The values are expected to be overwritten before use, e.g. by a
            # codebook-initialisation callback or by loading a checkpoint.
            # Wrapping in nn.Parameter explicitly keeps this working on PyTorch versions
            # whose ParameterList does not convert plain tensors.
            codebook = nn.Parameter(torch.zeros(codebook_size, embedding_dim))
            self.codebooks.append(codebook)

    @staticmethod
    def make_encoding_tower(d1, d2, bias=False):
        """Return a 3-layer MLP ``d1 -> d1 -> d2 -> d2`` with ReLU in between.

        ``bias`` only controls the bias of the last linear layer.
        """
        return torch.nn.Sequential(
            nn.Linear(d1, d1),
            nn.ReLU(),
            nn.Linear(d1, d2),
            nn.ReLU(),
            nn.Linear(d2, d2, bias=bias)
        )

    @staticmethod
    def get_codebook_indices(remainder, codebook):
        """Return, for every row of ``remainder``, the index of the closest codebook row."""
        dist = torch.cdist(remainder, codebook)
        return dist.argmin(dim=-1)

    def _quantize_representation(self, latent_vector):
        """Residually quantise ``latent_vector`` and return the sum of the chosen codes.

        Uses the straight-through estimator, so gradients flow to ``latent_vector`` as
        if the quantisation were the identity.
        """
        latent_restored = 0
        remainder = latent_vector

        for codebook in self.codebooks:
            codebook_indices = self.get_codebook_indices(remainder, codebook)
            quantized = codebook[codebook_indices]
            # Forward value is `quantized`; backward gradient goes to `remainder`.
            codebook_vectors = remainder + (quantized - remainder).detach()
            latent_restored += codebook_vectors
            remainder = remainder - codebook_vectors

        return latent_restored

    def contrastive_loss(self, p_i, p_i_star):
        """In-batch contrastive loss: row ``k`` of ``p_i`` should match row ``k`` of ``p_i_star``.

        Both inputs are L2-normalised, their pairwise similarities are divided by
        ``self.temperature`` and fed to a cross-entropy with the diagonal as target.

        Raises:
            ValueError: if ``self.temperature`` is not positive (the constructor
                default of 0.0 would otherwise produce inf/NaN logits).
        """
        if self.temperature <= 0:
            raise ValueError(
                f'temperature must be positive to compute the contrastive loss, got {self.temperature}'
            )

        N_b = p_i.size(0)

        p_i = F.normalize(p_i, p=2, dim=-1)  # TODO: try without normalisation
        p_i_star = F.normalize(p_i_star, p=2, dim=-1)

        similarities = torch.matmul(p_i, p_i_star.T) / self.temperature

        labels = torch.arange(N_b, dtype=torch.long, device=p_i.device)

        loss = F.cross_entropy(similarities, labels)

        return loss

    def forward(self, inputs):
        """Encode, residually quantise and decode a batch.

        Args:
            inputs: mapping with
                * ``'embedding'`` — 2-D float tensor ``(batch, input_dim)``;
                * ``'cooccurrence_embedding'`` (optional) — tensor of the same shape;
                  when present the contrastive loss is added.

        Returns:
            ``(loss, outputs)`` where ``loss`` is the scalar training loss and
            ``outputs`` is a dict with the scalar losses as Python floats
            (``loss``, ``recon_loss``, ``rqvae_loss``, ``con_loss``), per-codebook usage
            histograms (``clusters_counts``), the selected indices ``clusters`` of shape
            ``(batch, num_codebooks)`` and the reconstruction ``embedding_hat``.
        """
        latent_vector = self.encoder(inputs['embedding'])

        latent_restored = 0
        rqvae_loss = 0
        clusters = []
        remainder = latent_vector

        for codebook in self.codebooks:
            codebook_indices = self.get_codebook_indices(remainder, codebook)
            clusters.append(codebook_indices)

            quantized = codebook[codebook_indices]
            # Straight-through estimator: value of `quantized`, gradient of `remainder`.
            codebook_vectors = remainder + (quantized - remainder).detach()

            # Commitment term (pulls the encoder output towards the code) ...
            rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach())
            # ... and codebook term (pulls the code towards the encoder output).
            rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach())

            latent_restored += codebook_vectors
            # NOTE(review): algebraically this equals -(quantized - remainder).detach(),
            # so from the second level on `remainder` carries no gradient back to the
            # encoder and the commitment term above only affects level 0. Confirm this
            # is intended before changing it — existing checkpoints were trained this way.
            remainder = remainder - codebook_vectors

        embeddings_restored = self.decoder(latent_restored)
        recon_loss = torch.nn.functional.mse_loss(embeddings_restored, inputs['embedding'])

        if 'cooccurrence_embedding' in inputs:
            cooccurrence_latent = self.encoder(inputs['cooccurrence_embedding'].to(latent_restored.device))
            cooccurrence_restored = self._quantize_representation(cooccurrence_latent)
            con_loss = self.contrastive_loss(latent_restored, cooccurrence_restored)
        else:
            con_loss = torch.as_tensor(0.0, device=latent_vector.device)

        loss = (
            recon_loss
            + self.quant_loss_weight * rqvae_loss
            + self.contrastive_loss_weight * con_loss
        ).mean()

        # How many batch rows selected each code, per codebook.
        clusters_counts = [
            torch.bincount(cluster, minlength=self.codebook_size)
            for cluster in clusters
        ]

        return loss, {
            'loss': loss.item(),
            'recon_loss': recon_loss.mean().item(),
            'rqvae_loss': rqvae_loss.mean().item(),
            'con_loss': con_loss.item(),

            'clusters_counts': clusters_counts,
            'clusters': torch.stack(clusters).T,
            'embedding_hat': embeddings_restored,
        }
\ No newline at end of file
diff --git a/scripts/plum-lsvd/train_plum_4.1.py b/scripts/plum-lsvd/train_plum_4.1.py
new file mode 100644
index 0000000..85027b3
--- /dev/null
+++ b/scripts/plum-lsvd/train_plum_4.1.py
@@ -0,0 +1,180 @@
+from loguru import logger
+import os
+
+import torch
+
+import pickle
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from models import PlumRQVAE
+from transforms import AddWeightedCooccurrenceEmbeddingsVectorized
+from cooc_data import CoocMappingDataset
+
# Experiment with full history.
SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

NUM_EPOCHS = 35
BATCH_SIZE = 1024

# Model hyper-parameters forwarded to PlumRQVAE in main().
INPUT_DIM = 64
HIDDEN_DIM = 32
CODEBOOK_SIZE = 512
NUM_CODEBOOKS = 3
BETA = 0.25
LR = 1e-4
# WINDOW_SIZE is passed to CoocMappingDataset.create_from_split_part(window_size=...).
WINDOW_SIZE = 2
# K is passed as max_neighbors to AddWeightedCooccurrenceEmbeddingsVectorized.
K=2000

EXPERIMENT_NAME = f'4-1_vk_lsvd_ods_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_8w_e{NUM_EPOCHS}'
INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/base_with_gap_interactions_grouped.parquet"
EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/items_metadata_remapped.parquet"
# Relative path: tensorboard logs and checkpoints resolve against the current working directory.
IREC_PATH = '../../'

# Executed at import time, before main() runs.
print(INTER_TRAIN_PATH)
def main():
    """Train PlumRQVAE on item embeddings with co-occurrence based contrastive pairs.

    Builds the embedding dataset and the co-occurrence mapping from
    ``INTER_TRAIN_PATH``, attaches a co-occurring item's embedding to every batch via
    ``AddWeightedCooccurrenceEmbeddingsVectorized``, and runs ``TrainingRunner`` with
    metric accumulation, validation, logging, TensorBoard and early-stopping callbacks
    firing every ``LOG_EVERY_NUM_STEPS`` steps.
    """
    fix_random_seed(SEED_VALUE)

    dataset = EmbeddingDatasetParquet(
        data_path=EMBEDDINGS_PATH,
    )

    data = CoocMappingDataset.create_from_split_part(
        train_inter_parquet_path=INTER_TRAIN_PATH,
        window_size=WINDOW_SIZE
    )

    # Lookup tables handed to the co-occurrence transform: every raw embedding on DEVICE.
    item_id_to_embedding = {}
    all_item_ids = []
    for idx in range(len(dataset)):
        sample = dataset[idx]
        item_id = int(sample['item_id'])
        item_id_to_embedding[item_id] = torch.tensor(sample['embedding'], device=DEVICE)
        all_item_ids.append(item_id)

    add_cooc_transform = AddWeightedCooccurrenceEmbeddingsVectorized(
        cooccur_counts=data.cooccur_counter_mapping,
        item_id_to_embedding=item_id_to_embedding,
        all_item_ids=all_item_ids,
        device=DEVICE,
        max_neighbors=K,
        seed=SEED_VALUE
    )

    # TODO: the transform call currently runs in the main thread; needs fixing.
    train_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
    ).map(add_cooc_transform
    ).repeat(NUM_EPOCHS)

    valid_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
    ).map(add_cooc_transform)

    # Dataloader length divided by the number of epochs; never let the interval drop to 0
    # (that would happen when the dataloader is shorter than NUM_EPOCHS steps).
    LOG_EVERY_NUM_STEPS = max(1, len(train_dataloader) // NUM_EPOCHS)

    model = PlumRQVAE(
        input_dim=INPUT_DIM,
        num_codebooks=NUM_CODEBOOKS,
        codebook_size=CODEBOOK_SIZE,
        embedding_dim=HIDDEN_DIM,
        beta=BETA,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=1.0
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    # The fused Adam kernel is only requested on CUDA; DEVICE falls back to CPU when no
    # GPU is available, and fused=True is rejected for CPU parameters on some PyTorch versions.
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=(DEVICE.type == 'cuda'))

    callbacks = [
        InitCodebooks(valid_dataloader),

        cb.BatchMetrics(metrics=lambda model_outputs, batch: {
            'loss': model_outputs['loss'],
            'recon_loss': model_outputs['recon_loss'],
            'rqvae_loss': model_outputs['rqvae_loss'],
            'con_loss': model_outputs['con_loss']
        }, name='train'),

        FixDeadCentroids(valid_dataloader),

        cb.MetricAccumulator(
            accumulators={
                'train/loss': cb.MeanAccumulator(),
                'train/recon_loss': cb.MeanAccumulator(),
                'train/rqvae_loss': cb.MeanAccumulator(),
                'train/con_loss': cb.MeanAccumulator(),
                'num_dead/0': cb.MeanAccumulator(),
                'num_dead/1': cb.MeanAccumulator(),
                'num_dead/2': cb.MeanAccumulator(),
            },
            reset_every_num_steps=LOG_EVERY_NUM_STEPS
        ),

        cb.Validation(
            dataset=valid_dataloader,
            callbacks=[
                cb.BatchMetrics(metrics=lambda model_outputs, batch: {
                    'loss': model_outputs['loss'],
                    'recon_loss': model_outputs['recon_loss'],
                    'rqvae_loss': model_outputs['rqvae_loss'],
                    'con_loss': model_outputs['con_loss']
                }, name='valid'),
                cb.MetricAccumulator(
                    accumulators={
                        'valid/loss': cb.MeanAccumulator(),
                        'valid/recon_loss': cb.MeanAccumulator(),
                        'valid/rqvae_loss': cb.MeanAccumulator(),
                        'valid/con_loss': cb.MeanAccumulator()
                    }
                ),
            ],
        ).every_num_steps(LOG_EVERY_NUM_STEPS),

        cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),

        cb.EarlyStopping(
            metric='valid/recon_loss',
            patience=40,
            minimize=True,
            model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
        ).every_num_steps(LOG_EVERY_NUM_STEPS),
    ]

    logger.debug('Everything is ready for training process!')

    runner = TrainingRunner(
        model=model,
        optimizer=optimizer,
        dataset=train_dataloader,
        callbacks=callbacks,
    )
    runner.run()


if __name__ == '__main__':
    main()
diff --git a/scripts/plum-lsvd/train_plum_4.2.py b/scripts/plum-lsvd/train_plum_4.2.py
new file mode 100644
index 0000000..de1864d
--- /dev/null
+++ b/scripts/plum-lsvd/train_plum_4.2.py
@@ -0,0 +1,180 @@
+from loguru import logger
+import os
+
+import torch
+
+import pickle
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from models import PlumRQVAE
+from transforms import AddWeightedCooccurrenceEmbeddingsVectorized
+from cooc_data import CoocMappingDataset
+
# Experiment with truncated history.
SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

NUM_EPOCHS = 35
BATCH_SIZE = 1024

# Model hyper-parameters forwarded to PlumRQVAE in main().
INPUT_DIM = 64
HIDDEN_DIM = 32
CODEBOOK_SIZE = 512
NUM_CODEBOOKS = 3
BETA = 0.25
LR = 1e-4
# WINDOW_SIZE is passed to CoocMappingDataset.create_from_split_part(window_size=...).
WINDOW_SIZE = 2
# K is passed as max_neighbors to AddWeightedCooccurrenceEmbeddingsVectorized.
K=2000

# NOTE(review): the name still says "base_with_gap" although INTER_TRAIN_PATH below points
# to base_interactions_grouped.parquet — confirm the naming is intended.
EXPERIMENT_NAME = f'4-2_vk_lsvd_ods_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_8w_e{NUM_EPOCHS}'
INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/base_interactions_grouped.parquet"
EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/items_metadata_remapped.parquet"
# Relative path: tensorboard logs and checkpoints resolve against the current working directory.
IREC_PATH = '../../'

# Executed at import time, before main() runs.
print(INTER_TRAIN_PATH)
def main():
    """Train PlumRQVAE on item embeddings with co-occurrence based contrastive pairs.

    Builds the embedding dataset and the co-occurrence mapping from
    ``INTER_TRAIN_PATH``, attaches a co-occurring item's embedding to every batch via
    ``AddWeightedCooccurrenceEmbeddingsVectorized``, and runs ``TrainingRunner`` with
    metric accumulation, validation, logging, TensorBoard and early-stopping callbacks
    firing every ``LOG_EVERY_NUM_STEPS`` steps.
    """
    fix_random_seed(SEED_VALUE)

    dataset = EmbeddingDatasetParquet(
        data_path=EMBEDDINGS_PATH,
    )

    data = CoocMappingDataset.create_from_split_part(
        train_inter_parquet_path=INTER_TRAIN_PATH,
        window_size=WINDOW_SIZE
    )

    # Lookup tables handed to the co-occurrence transform: every raw embedding on DEVICE.
    item_id_to_embedding = {}
    all_item_ids = []
    for idx in range(len(dataset)):
        sample = dataset[idx]
        item_id = int(sample['item_id'])
        item_id_to_embedding[item_id] = torch.tensor(sample['embedding'], device=DEVICE)
        all_item_ids.append(item_id)

    add_cooc_transform = AddWeightedCooccurrenceEmbeddingsVectorized(
        cooccur_counts=data.cooccur_counter_mapping,
        item_id_to_embedding=item_id_to_embedding,
        all_item_ids=all_item_ids,
        device=DEVICE,
        max_neighbors=K,
        seed=SEED_VALUE
    )

    # TODO: the transform call currently runs in the main thread; needs fixing.
    train_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
    ).map(add_cooc_transform
    ).repeat(NUM_EPOCHS)

    valid_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
    ).map(add_cooc_transform)

    # Dataloader length divided by the number of epochs; never let the interval drop to 0
    # (that would happen when the dataloader is shorter than NUM_EPOCHS steps).
    LOG_EVERY_NUM_STEPS = max(1, len(train_dataloader) // NUM_EPOCHS)

    model = PlumRQVAE(
        input_dim=INPUT_DIM,
        num_codebooks=NUM_CODEBOOKS,
        codebook_size=CODEBOOK_SIZE,
        embedding_dim=HIDDEN_DIM,
        beta=BETA,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=1.0
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    # The fused Adam kernel is only requested on CUDA; DEVICE falls back to CPU when no
    # GPU is available, and fused=True is rejected for CPU parameters on some PyTorch versions.
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=(DEVICE.type == 'cuda'))

    callbacks = [
        InitCodebooks(valid_dataloader),

        cb.BatchMetrics(metrics=lambda model_outputs, batch: {
            'loss': model_outputs['loss'],
            'recon_loss': model_outputs['recon_loss'],
            'rqvae_loss': model_outputs['rqvae_loss'],
            'con_loss': model_outputs['con_loss']
        }, name='train'),

        FixDeadCentroids(valid_dataloader),

        cb.MetricAccumulator(
            accumulators={
                'train/loss': cb.MeanAccumulator(),
                'train/recon_loss': cb.MeanAccumulator(),
                'train/rqvae_loss': cb.MeanAccumulator(),
                'train/con_loss': cb.MeanAccumulator(),
                'num_dead/0': cb.MeanAccumulator(),
                'num_dead/1': cb.MeanAccumulator(),
                'num_dead/2': cb.MeanAccumulator(),
            },
            reset_every_num_steps=LOG_EVERY_NUM_STEPS
        ),

        cb.Validation(
            dataset=valid_dataloader,
            callbacks=[
                cb.BatchMetrics(metrics=lambda model_outputs, batch: {
                    'loss': model_outputs['loss'],
                    'recon_loss': model_outputs['recon_loss'],
                    'rqvae_loss': model_outputs['rqvae_loss'],
                    'con_loss': model_outputs['con_loss']
                }, name='valid'),
                cb.MetricAccumulator(
                    accumulators={
                        'valid/loss': cb.MeanAccumulator(),
                        'valid/recon_loss': cb.MeanAccumulator(),
                        'valid/rqvae_loss': cb.MeanAccumulator(),
                        'valid/con_loss': cb.MeanAccumulator()
                    }
                ),
            ],
        ).every_num_steps(LOG_EVERY_NUM_STEPS),

        cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),

        cb.EarlyStopping(
            metric='valid/recon_loss',
            patience=40,
            minimize=True,
            model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
        ).every_num_steps(LOG_EVERY_NUM_STEPS),
    ]

    logger.debug('Everything is ready for training process!')

    runner = TrainingRunner(
        model=model,
        optimizer=optimizer,
        dataset=train_dataloader,
        callbacks=callbacks,
    )
    runner.run()


if __name__ == '__main__':
    main()
diff --git a/scripts/plum-lsvd/train_rqvae.py b/scripts/plum-lsvd/train_rqvae.py
new file mode 100644
index 0000000..ea41b74
--- /dev/null
+++ b/scripts/plum-lsvd/train_rqvae.py
@@ -0,0 +1,174 @@
+from loguru import logger
+import os
+
+import torch
+
+import pickle
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from models import PlumRQVAE
+from transforms import AddWeightedCooccurrenceEmbeddings
+from cooc_data import CoocMappingDataset
+
SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

NUM_EPOCHS = 15
BATCH_SIZE = 1024

# Model hyper-parameters forwarded to PlumRQVAE in main().
INPUT_DIM = 64
HIDDEN_DIM = 32
CODEBOOK_SIZE = 512
NUM_CODEBOOKS = 3
BETA = 0.25
LR = 1e-4
# NOTE(review): WINDOW_SIZE, MAX_LEN and K are only referenced by commented-out code in
# this script (the co-occurrence pipeline is disabled here).
WINDOW_SIZE = 3
MAX_LEN = 500
K=100

# EXPERIMENT_NAME = f'4-1_vk_lsvd_ods_base_with_gap_cb_{CODEBOOK_SIZE}_ws_{WINDOW_SIZE}_k_{K}_ml_{MAX_LEN}'
EXPERIMENT_NAME = f'rqvae_vk_lsvd_cz_512_8-weeks'
EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/items_metadata_remapped.parquet"
# Relative path: tensorboard logs and checkpoints resolve against the current working directory.
IREC_PATH = '../../'

# print(INTER_TRAIN_PATH)
def main():
    """Train the plain RQ-VAE baseline (no co-occurrence / contrastive pairs).

    Batches contain only the item embeddings, so ``PlumRQVAE.forward`` never sees a
    ``'cooccurrence_embedding'`` key and the contrastive loss stays at zero. Runs
    ``TrainingRunner`` with metric accumulation, validation, logging, TensorBoard and
    early-stopping callbacks firing every ``LOG_EVERY_NUM_STEPS`` steps.
    """
    fix_random_seed(SEED_VALUE)

    dataset = EmbeddingDatasetParquet(
        data_path=EMBEDDINGS_PATH
    )

    # The co-occurrence transform (CoocMappingDataset + AddWeightedCooccurrenceEmbeddings)
    # is intentionally not used in this baseline, so no per-item embedding lookup table
    # is built or copied to DEVICE here.

    # TODO: the transform call currently runs in the main thread; needs fixing.
    train_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
    ).repeat(NUM_EPOCHS)

    valid_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
    )

    # Dataloader length divided by the number of epochs; never let the interval drop to 0
    # (that would happen when the dataloader is shorter than NUM_EPOCHS steps).
    LOG_EVERY_NUM_STEPS = max(1, len(train_dataloader) // NUM_EPOCHS)

    model = PlumRQVAE(
        input_dim=INPUT_DIM,
        num_codebooks=NUM_CODEBOOKS,
        codebook_size=CODEBOOK_SIZE,
        embedding_dim=HIDDEN_DIM,
        beta=BETA,
        quant_loss_weight=1.0,
        contrastive_loss_weight=1.0,
        temperature=1.0
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    # The fused Adam kernel is only requested on CUDA; DEVICE falls back to CPU when no
    # GPU is available, and fused=True is rejected for CPU parameters on some PyTorch versions.
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=(DEVICE.type == 'cuda'))

    callbacks = [
        InitCodebooks(valid_dataloader),

        cb.BatchMetrics(metrics=lambda model_outputs, batch: {
            'loss': model_outputs['loss'],
            'recon_loss': model_outputs['recon_loss'],
            'rqvae_loss': model_outputs['rqvae_loss'],
            'con_loss': model_outputs['con_loss']
        }, name='train'),

        FixDeadCentroids(valid_dataloader),

        cb.MetricAccumulator(
            accumulators={
                'train/loss': cb.MeanAccumulator(),
                'train/recon_loss': cb.MeanAccumulator(),
                'train/rqvae_loss': cb.MeanAccumulator(),
                'train/con_loss': cb.MeanAccumulator(),
                'num_dead/0': cb.MeanAccumulator(),
                'num_dead/1': cb.MeanAccumulator(),
                'num_dead/2': cb.MeanAccumulator(),
            },
            reset_every_num_steps=LOG_EVERY_NUM_STEPS
        ),

        cb.Validation(
            dataset=valid_dataloader,
            callbacks=[
                cb.BatchMetrics(metrics=lambda model_outputs, batch: {
                    'loss': model_outputs['loss'],
                    'recon_loss': model_outputs['recon_loss'],
                    'rqvae_loss': model_outputs['rqvae_loss'],
                    'con_loss': model_outputs['con_loss']
                }, name='valid'),
                cb.MetricAccumulator(
                    accumulators={
                        'valid/loss': cb.MeanAccumulator(),
                        'valid/recon_loss': cb.MeanAccumulator(),
                        'valid/rqvae_loss': cb.MeanAccumulator(),
                        'valid/con_loss': cb.MeanAccumulator()
                    }
                ),
            ],
        ).every_num_steps(LOG_EVERY_NUM_STEPS),

        cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),

        cb.EarlyStopping(
            metric='valid/recon_loss',
            patience=40,
            minimize=True,
            model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
        ).every_num_steps(LOG_EVERY_NUM_STEPS),
    ]

    logger.debug('Everything is ready for training process!')

    runner = TrainingRunner(
        model=model,
        optimizer=optimizer,
        dataset=train_dataloader,
        callbacks=callbacks,
    )
    runner.run()


if __name__ == '__main__':
    main()
diff --git a/scripts/plum-lsvd/transforms.py b/scripts/plum-lsvd/transforms.py
new file mode 100644
index 0000000..143002b
--- /dev/null
+++ b/scripts/plum-lsvd/transforms.py
@@ -0,0 +1,287 @@
+import numpy as np
+import pickle
+import torch
+from typing import Dict, List
+import time
+from collections import defaultdict, Counter
+
+class AddWeightedCooccurrenceEmbeddings:
+    """Batch transform that attaches one sampled co-occurring item's embedding per item.
+
+    For every entry of ``batch['item_id']`` a neighbour is drawn from the item's
+    ``top_k`` most frequent co-occurring items, with probability proportional to
+    the co-occurrence count. Items without co-occurrence data get a uniformly
+    random id from ``all_item_ids`` instead.
+
+    Args:
+        cooccur_counts: mapping item_id -> ``collections.Counter`` of
+            neighbour_id -> count (``most_common`` is called on the values).
+        item_id_to_embedding: mapping item_id -> embedding tensor.
+        all_item_ids: sequence of item ids used for the random fallback.
+        top_k: number of most frequent neighbours kept per item.
+    """
+
+    def __init__(self, cooccur_counts, item_id_to_embedding, all_item_ids, top_k):
+        self.cooccur_counts = cooccur_counts
+        self.item_id_to_embedding = item_id_to_embedding
+        self.all_item_ids = all_item_ids
+        self.call_count = 0
+        self.top_k = top_k
+
+        # Precompute the top-k neighbours (and their probabilities) of every item once.
+        self._top_k_cache = {}
+        self._build_top_k_cache()
+
+    def _build_top_k_cache(self):
+        """Precompute the top-k neighbour ids and sampling probabilities for each item."""
+        for item_id, counter in self.cooccur_counts.items():
+            if not counter:
+                continue
+            top_items = counter.most_common(self.top_k)
+            if not top_items:
+                # top_k == 0 leaves nothing to sample from; the item uses the random fallback.
+                continue
+            cooc_ids, freqs = zip(*top_items)
+            freqs_array = np.array(freqs, dtype=np.float32)
+
+            self._top_k_cache[item_id] = {
+                # Stored as an array so np.random.choice does not convert a tuple on every draw.
+                'cooc_ids': np.asarray(cooc_ids),
+                'probs': freqs_array / freqs_array.sum(),
+            }
+
+    def __call__(self, batch):
+        """Add ``batch['cooccurrence_embedding']`` (one stacked row per item) and return the batch.
+
+        Reads ``batch['item_id']`` and ``batch['embedding']``. If the sampled
+        neighbour has no entry in ``item_id_to_embedding``, the item's own
+        embedding row is used.
+        """
+        self.call_count += 1
+        item_ids = batch['item_id']
+        log_this_call = self.call_count % 500 == 0
+        cooccurrence_embeddings = []
+
+        for idx, item_id in enumerate(item_ids):
+            item_id_val = int(item_id.item()) if torch.is_tensor(item_id) else int(item_id)
+
+            cache_entry = self._top_k_cache.get(item_id_val)
+            if cache_entry is not None:
+                cooc_id = np.random.choice(
+                    cache_entry['cooc_ids'],
+                    p=cache_entry['probs']
+                )
+            else:
+                cooc_id = np.random.choice(self.all_item_ids)
+                if log_this_call and idx < 5:
+                    print(f" idx={idx}: item_id={item_id_val} fallback random")
+            if log_this_call and idx < 5:
+                print(f" idx={idx}: item_id={item_id_val} cooc_id={cooc_id}")
+
+            # Fall back to this item's own row; the previous code always used row 0 of the batch.
+            cooc_emb = self.item_id_to_embedding.get(cooc_id, batch['embedding'][idx])
+            cooccurrence_embeddings.append(cooc_emb)
+
+        batch['cooccurrence_embedding'] = torch.stack(cooccurrence_embeddings)
+        return batch
+
+# TODO: run SASRec and LETTER; sasrec << tiger < letter < plum
+
+
+class AddWeightedCooccurrenceEmbeddingsVectorized:
+    """Tensor-based transform that samples one co-occurring item embedding per batch item.
+
+    Item ids are remapped to dense row indices (their position in the sorted
+    ``all_item_ids``). Three aligned ``(num_items, K)`` tables are built once:
+    neighbour row indices, sampling probabilities and a validity mask. Rows of
+    items without usable co-occurrence data are filled with a uniform
+    distribution over all items (or over a random subset of size ``K``).
+
+    Args:
+        cooccur_counts: item_id -> {neighbour_id: count}.
+        item_id_to_embedding: item_id -> 1-D embedding (tensor or array-like);
+            every key must be present in ``all_item_ids`` (KeyError otherwise).
+        all_item_ids: ids that define the dense index space.
+        device: device on which all tables are stored.
+        limit_neighbors: cap the number of neighbours per item at ``max_neighbors``.
+        max_neighbors: the cap used when ``limit_neighbors`` is True.
+        seed: passed to ``torch.manual_seed`` and ``np.random.seed``; note that
+            this reseeds the global generators.
+        verbose: print size and timing diagnostics.
+    """
+
+    def __init__(
+        self,
+        cooccur_counts: Dict[int, Dict[int, int]],
+        item_id_to_embedding: Dict[int, torch.Tensor],
+        all_item_ids: List[int],
+        device: torch.device,
+        limit_neighbors: bool = True,
+        max_neighbors: int = 256,
+        seed: int = 42,
+        verbose: bool = True
+    ):
+        self.device = device
+        self.call_count = 0
+        self.limit_neighbors = limit_neighbors
+        self.max_neighbors = max_neighbors
+        self.seed = seed
+        self.verbose = verbose
+
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+
+        if self.verbose:
+            print(f"\n{'='*80}")
+            print(f"Initializing AddWeightedCooccurrenceEmbeddingsVectorized")
+            print(f"{'='*80}")
+        init_start = time.time()
+
+        # Dense index space: row i of every table belongs to all_item_ids_sorted[i].
+        all_item_ids_sorted = sorted(all_item_ids)
+        self.item_id_to_idx = {item_id: idx for idx, item_id in enumerate(all_item_ids_sorted)}
+        self.idx_to_item_id = torch.tensor(all_item_ids_sorted, device=device, dtype=torch.long)
+
+        if self.verbose:
+            print(f"[INIT] Sorted {len(all_item_ids)} item IDs and created mappings")
+
+        num_items = len(all_item_ids_sorted)
+        embedding_dim = next(iter(item_id_to_embedding.values())).shape[0]
+
+        if self.verbose:
+            print(f"[INIT] Num items: {num_items}, Embedding dim: {embedding_dim}")
+
+        # Items without an embedding keep an all-zero row.
+        self.embedding_matrix = torch.zeros(
+            size=(num_items, embedding_dim),
+            device=device,
+            dtype=torch.float32,
+            requires_grad=False
+        )
+
+        emb_load_start = time.time()
+        for item_id, emb in item_id_to_embedding.items():
+            idx = self.item_id_to_idx[item_id]
+            if isinstance(emb, torch.Tensor):
+                self.embedding_matrix[idx] = emb.to(device).detach()
+            else:
+                self.embedding_matrix[idx] = torch.tensor(emb, device=device, dtype=torch.float32)
+
+        if self.verbose:
+            emb_load_time = time.time() - emb_load_start
+            print(f"[INIT] Loaded {len(item_id_to_embedding)} embeddings in {emb_load_time*1000:.2f}ms")
+
+        self._build_cooccurrence_tables(cooccur_counts, num_items)
+
+        if self.verbose:
+            init_time = time.time() - init_start
+            print(f"[INIT] Total initialization time: {init_time*1000:.2f}ms")
+            print(f"{'='*80}\n")
+
+    def _build_cooccurrence_tables(self, cooccur_counts: Dict, num_items: int):
+        """Build ``neighbors_matrix``, ``probs_matrix`` and ``valid_mask``.
+
+        Ids that are not part of the dense index space are dropped. Each row
+        keeps the most frequent neighbours (at most ``K`` of them) with
+        probabilities proportional to their counts.
+        """
+        if self.verbose:
+            build_start = time.time()
+            print(f"\n[BUILD] Building cooccurrence tables...")
+
+        # Re-key the counts from item ids to dense row indices.
+        indexed_cooccur_counts = {}
+        for item_id, neighbors in cooccur_counts.items():
+            if item_id in self.item_id_to_idx:
+                idx = self.item_id_to_idx[item_id]
+                indexed_neighbors = {}
+                for neighbor_id, count in neighbors.items():
+                    if neighbor_id in self.item_id_to_idx:
+                        neighbor_idx = self.item_id_to_idx[neighbor_id]
+                        indexed_neighbors[neighbor_idx] = count
+                if indexed_neighbors:
+                    indexed_cooccur_counts[idx] = indexed_neighbors
+
+        if self.verbose:
+            items_with_cooc = len(indexed_cooccur_counts)
+            print(f"[BUILD] Items with cooccurrences: {items_with_cooc}/{num_items}")
+            total_pairs = sum(len(neighbors) for neighbors in indexed_cooccur_counts.values())
+            print(f"[BUILD] Total cooccurrence pairs: {total_pairs}")
+
+        # Table width K: the widest row needed, optionally capped by max_neighbors.
+        max_actual_neighbors = 0
+        for idx in range(num_items):
+            counter = indexed_cooccur_counts.get(idx)
+            if counter and len(counter) > 0:
+                num_neighbors = len(counter)
+                if self.limit_neighbors:
+                    num_neighbors = min(num_neighbors, self.max_neighbors)
+            else:
+                # Items without neighbours want a uniform row over all items.
+                num_neighbors = num_items
+            max_actual_neighbors = max(max_actual_neighbors, num_neighbors)
+
+        if self.limit_neighbors:
+            max_actual_neighbors = min(max_actual_neighbors, self.max_neighbors)
+
+        if self.verbose:
+            print(f"[BUILD] Max neighbors per item: {max_actual_neighbors}")
+
+        neighbors_matrix = torch.zeros(
+            (num_items, max_actual_neighbors),
+            dtype=torch.long,
+            device=self.device,
+            requires_grad=False
+        )
+
+        probs_matrix = torch.zeros(
+            (num_items, max_actual_neighbors),
+            dtype=torch.float32,
+            device=self.device,
+            requires_grad=False
+        )
+
+        valid_mask = torch.zeros(
+            (num_items, max_actual_neighbors),
+            dtype=torch.bool,
+            device=self.device,
+            requires_grad=False
+        )
+
+        matrix_fill_start = time.time()
+
+        for idx in range(num_items):
+            counter = indexed_cooccur_counts.get(idx)
+
+            if counter and len(counter) > 0:
+                # Sort by count (descending) so truncation keeps the most frequent neighbours.
+                # The previous key compared whole (index, count) tuples, i.e. sorted by index.
+                cooc_items = sorted(counter.items(), key=lambda pair: pair[1], reverse=True)
+                cooc_ids, freqs = zip(*cooc_items)
+                cooc_ids = list(cooc_ids)
+                freqs = np.array(freqs, dtype=np.float32)
+
+                num_neighbors = min(len(cooc_ids), max_actual_neighbors)
+                cooc_ids = cooc_ids[:num_neighbors]
+                freqs = freqs[:num_neighbors]
+
+                # Renormalise over the neighbours that survived truncation.
+                probs = freqs / freqs.sum()
+
+                neighbors_matrix[idx, :num_neighbors] = torch.tensor(
+                    cooc_ids, dtype=torch.long, device=self.device
+                )
+                probs_matrix[idx, :num_neighbors] = torch.tensor(
+                    probs, dtype=torch.float32, device=self.device
+                )
+                valid_mask[idx, :num_neighbors] = True
+
+            else:
+                if max_actual_neighbors >= num_items:
+                    # The row is wide enough to hold every item: uniform over all of them.
+                    neighbors_matrix[idx, :num_items] = torch.arange(num_items, device=self.device)
+                    probs_matrix[idx, :num_items] = 1.0 / num_items
+                    valid_mask[idx, :num_items] = True
+                else:
+                    # Otherwise: uniform over a random subset that fits the row.
+                    perm = torch.randperm(num_items, device=self.device)[:max_actual_neighbors]
+                    neighbors_matrix[idx] = perm
+                    probs_matrix[idx] = 1.0 / max_actual_neighbors
+                    valid_mask[idx] = True
+
+        if self.verbose:
+            matrix_fill_time = time.time() - matrix_fill_start
+            print(f"[BUILD] Filled matrices in {matrix_fill_time*1000:.2f}ms")
+
+        self.neighbors_matrix = neighbors_matrix
+        self.probs_matrix = probs_matrix
+        self.valid_mask = valid_mask
+
+        if self.verbose:
+            print(f"[BUILD] neighbors_matrix shape: {neighbors_matrix.shape}")
+            print(f"[BUILD] probs_matrix shape: {probs_matrix.shape}")
+            print(f"[BUILD] valid_mask shape: {valid_mask.shape}")
+            build_time = time.time() - build_start
+            print(f"[BUILD] Total build time: {build_time*1000:.2f}ms")
+
+    def __call__(self, batch):
+        """Add ``batch['cooccurrence_embedding']`` (one row per item id) and return the batch.
+
+        ``batch['item_id']`` is expected to be a 1-D tensor or sequence of ids.
+        Ids missing from the index space are silently mapped to row 0.
+        """
+        self.call_count += 1
+
+        call_start = time.time()
+
+        item_ids = batch['item_id']
+
+        if not isinstance(item_ids, torch.Tensor):
+            item_ids = torch.tensor(item_ids, device=self.device, dtype=torch.long)
+        else:
+            item_ids = item_ids.to(device=self.device, dtype=torch.long)
+
+        # Number of items in the batch (an int, not the whole torch.Size).
+        batch_size = item_ids.shape[0]
+
+        # One bulk transfer to Python ints instead of a .item() call per element.
+        indexed_item_ids = torch.tensor(
+            [self.item_id_to_idx.get(iid, 0) for iid in item_ids.tolist()],
+            device=self.device,
+            dtype=torch.long
+        )
+
+        probs = self.probs_matrix[indexed_item_ids]
+        mask = self.valid_mask[indexed_item_ids]
+
+        masked_probs = probs.clone()
+        masked_probs[~mask] = 0.0
+
+        # Renormalise each row; guard against division by zero for empty rows.
+        row_sums = masked_probs.sum(dim=1, keepdim=True)
+        row_sums[row_sums == 0] = 1.0
+        masked_probs = masked_probs / row_sums
+
+        # One column index per row, then map it to the neighbour's dense row index.
+        neighbor_indices = torch.multinomial(masked_probs, num_samples=1, replacement=True)
+        neighbor_indices = neighbor_indices.squeeze(1)
+
+        cooc_indexed_ids = self.neighbors_matrix[indexed_item_ids, neighbor_indices]
+        cooccurrence_embeddings = self.embedding_matrix[cooc_indexed_ids]
+
+        batch['cooccurrence_embedding'] = cooccurrence_embeddings
+
+        call_time = time.time() - call_start
+        if self.verbose and self.call_count % 1000 == 0:
+            print(f"Call #{self.call_count}: batch_size={batch_size}, {call_time*1000:.2f}ms")
+
+        return batch
\ No newline at end of file
diff --git a/scripts/plum-yambda/callbacks.py b/scripts/plum-yambda/callbacks.py
new file mode 100644
index 0000000..43ec460
--- /dev/null
+++ b/scripts/plum-yambda/callbacks.py
@@ -0,0 +1,64 @@
+import torch
+
+import irec.callbacks as cb
+from irec.runners import TrainingRunner, TrainingRunnerContext
+
+class InitCodebooks(cb.TrainingCallback):
+    """Callback that seeds every codebook from encoded data before training starts."""
+
+    def __init__(self, dataloader):
+        super().__init__()
+        self._dataloader = dataloader
+
+    @torch.no_grad()
+    def before_run(self, runner: TrainingRunner):
+        """Overwrite each codebook with residuals of randomly chosen encoded rows.
+
+        For level ``i`` a batch is drawn, up to ``len(codebooks[i])`` rows are
+        picked at random and encoded, and the nearest vectors of levels
+        ``0..i-1`` are subtracted; what is left becomes the level-``i`` codebook.
+        """
+        model = runner.model
+        for level in range(len(model.codebooks)):
+            # A new iterator per level: each level takes the first batch of a fresh pass.
+            first_batch = next(iter(self._dataloader))['embedding']
+            num_rows = len(model.codebooks[level])
+            chosen = torch.randperm(first_batch.shape[0], device=first_batch.device)[:num_rows]
+            residual = model.encoder(first_batch[chosen])
+
+            # Remove the part already explained by the earlier, initialised levels.
+            for earlier in range(level):
+                earlier_codebook = model.codebooks[earlier]
+                nearest = model.get_codebook_indices(residual, earlier_codebook)
+                residual = residual - earlier_codebook[nearest]
+
+            model.codebooks[level].data = residual.detach()
+
+
+class FixDeadCentroids(cb.TrainingCallback):
+    """Callback that re-seeds codebook vectors that no sample is assigned to.
+
+    After every training step each codebook level is checked against the whole
+    ``dataloader``; centroids with zero assignments ("dead") are replaced by
+    residuals of randomly chosen rows of one batch. The number of dead
+    centroids per level is written to ``context.metrics['num_dead/<level>']``.
+    """
+
+    def __init__(self, dataloader):
+        super().__init__()
+        self._dataloader = dataloader
+
+    def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext):
+        for i, num_fixed in enumerate(self.fix_dead_codebooks(runner)):
+            context.metrics[f'num_dead/{i}'] = num_fixed
+
+    @staticmethod
+    def _residual_at_level(model, embeddings, level):
+        """Encode ``embeddings`` and subtract the nearest vectors of codebooks ``0..level-1``."""
+        remainder = model.encoder(embeddings)
+        for l in range(level):
+            ind = model.get_codebook_indices(remainder, model.codebooks[l])
+            remainder = remainder - model.codebooks[l][ind]
+        return remainder
+
+    @torch.no_grad()
+    def fix_dead_codebooks(self, runner: TrainingRunner):
+        """Replace dead centroids in place.
+
+        Returns:
+            list[int]: number of dead centroids found at each codebook level.
+        """
+        model = runner.model
+        num_fixed = []
+        for codebook_idx, codebook in enumerate(model.codebooks):
+            centroid_counts = torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device)
+            random_batch = next(iter(self._dataloader))['embedding']
+
+            # Count how many samples of the whole loader fall on each centroid.
+            for batch in self._dataloader:
+                remainder = self._residual_at_level(model, batch['embedding'], codebook_idx)
+                indices = model.get_codebook_indices(remainder, codebook)
+                centroid_counts.scatter_add_(0, indices, torch.ones_like(indices))
+
+            dead_mask = (centroid_counts == 0)
+            num_dead = int(dead_mask.sum().item())
+            num_fixed.append(num_dead)
+            if num_dead == 0:
+                continue
+
+            candidates = self._residual_at_level(model, random_batch, codebook_idx)
+            num_candidates = candidates.shape[0]
+            if num_dead <= num_candidates:
+                chosen = torch.randperm(num_candidates, device=codebook.device)[:num_dead]
+            else:
+                # More dead centroids than candidate rows: sample with replacement so the
+                # masked assignment below receives exactly num_dead rows instead of failing
+                # with a shape mismatch.
+                chosen = torch.randint(num_candidates, (num_dead,), device=codebook.device)
+            codebook[dead_mask] = candidates[chosen].detach()
+
+        return num_fixed
diff --git a/scripts/plum-yambda/cooc_data.py b/scripts/plum-yambda/cooc_data.py
new file mode 100644
index 0000000..50f2bdd
--- /dev/null
+++ b/scripts/plum-yambda/cooc_data.py
@@ -0,0 +1,108 @@
+import json
+import pickle
+from collections import defaultdict, Counter
+
+import numpy as np
+from loguru import logger
+
+
+import pickle
+from collections import defaultdict, Counter
+
+class CoocMappingDataset:
+    """Per-user training sequences together with a window-based item co-occurrence map."""
+
+    def __init__(self, train_sampler, num_items, cooccur_counter_mapping=None):
+        self._train_sampler = train_sampler
+        self._num_items = num_items
+        self._cooccur_counter_mapping = cooccur_counter_mapping
+
+    @property
+    def cooccur_counter_mapping(self):
+        """Mapping item_id -> Counter of neighbour_id -> count (or None if not provided)."""
+        return self._cooccur_counter_mapping
+
+    @staticmethod
+    def build_cooccur_counter_mapping(train_dataset, window_size):  # TODO: pass timestamps and build the window from them
+        """Count, for every item, the items seen within ``window_size`` positions of it.
+
+        Each element of ``train_dataset`` is a dict whose ``'item.ids'`` entry is
+        an ordered list of item ids. Returns a ``defaultdict(Counter)``.
+        """
+        cooccur_counts = defaultdict(Counter)
+        for session in train_dataset:
+            sequence = session['item.ids']
+            length = len(sequence)
+            for center, center_item in enumerate(sequence):
+                window_start = max(0, center - window_size)
+                window_stop = min(length, center + window_size + 1)
+                for position in range(window_start, window_stop):
+                    if position == center:
+                        continue
+                    cooccur_counts[center_item][sequence[position]] += 1
+        return cooccur_counts
+
+    @classmethod
+    def create(cls, inter_json_path, window_size):
+        """Build the dataset from a ``{user_id: [item_id, ...]}`` JSON file.
+
+        The last two items of every user are left out of the training sequence.
+        Every user must have at least five items.
+        """
+        with open(inter_json_path, 'r') as f:
+            user_interactions = json.load(f)
+
+        max_item_id = 0
+        train_dataset = []
+        for raw_user_id, item_ids in user_interactions.items():
+            user_id = int(raw_user_id)
+            if item_ids:
+                max_item_id = max(max_item_id, max(item_ids))
+            assert len(item_ids) >= 5, f'Core-5 dataset is used, user {user_id} has only {len(item_ids)} items'
+            train_dataset.append({'user.ids': [user_id], 'item.ids': item_ids[:-2]})
+
+        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size)
+        logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items but max_item_id is {max_item_id}')
+
+        return cls(
+            train_sampler=train_dataset,
+            num_items=max_item_id + 1,
+            cooccur_counter_mapping=cooccur_counter_mapping
+        )
+
+    @classmethod
+    def create_from_split_part(cls, train_inter_json_path, window_size):
+        """Build the dataset from a train-only ``{user_id: [item_id, ...]}`` JSON file.
+
+        The full sequence of every user is used as-is.
+        """
+        with open(train_inter_json_path, 'r') as f:
+            train_interactions = json.load(f)
+
+        max_item_id = 0
+        train_dataset = []
+        for raw_user_id, item_ids in train_interactions.items():
+            user_id = int(raw_user_id)
+            if item_ids:
+                max_item_id = max(max_item_id, max(item_ids))
+            train_dataset.append({'user.ids': [user_id], 'item.ids': item_ids})
+
+        logger.debug(f'Train: {len(train_dataset)} users')
+        logger.debug(f'Max item ID: {max_item_id}')
+
+        cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size)
+
+        logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items')
+
+        return cls(
+            train_sampler=train_dataset,
+            num_items=max_item_id + 1,
+            cooccur_counter_mapping=cooccur_counter_mapping
+        )
diff --git a/scripts/plum-yambda/data.py b/scripts/plum-yambda/data.py
new file mode 100644
index 0000000..842adb5
--- /dev/null
+++ b/scripts/plum-yambda/data.py
@@ -0,0 +1,62 @@
+import numpy as np
+import pickle
+
+from irec.data.base import BaseDataset
+from irec.data.transforms import Transform
+
+
+import polars as pl
+import numpy as np
+import torch
+
+class EmbeddingDatasetParquet(BaseDataset):
+    """Item-embedding dataset read from a parquet file with ``item_id`` and ``embedding`` columns."""
+
+    def __init__(self, data_path):
+        frame = pl.read_parquet(data_path)
+        self.df = frame
+        self.item_ids = np.array(frame['item_id'], dtype=np.int64)
+        embedding_rows = frame['embedding'].to_list()
+        self.embeddings = np.array(embedding_rows, dtype=np.float32)
+        print(f"embedding dim: {self.embeddings[0].shape}")
+
+    def __len__(self):
+        return len(self.embeddings)
+
+    def __getitem__(self, idx):
+        """Return the id, the embedding row and its length for position ``idx``."""
+        embedding = self.embeddings[idx]
+        return {
+            'item_id': self.item_ids[idx],
+            'embedding': embedding,
+            'embedding_dim': len(embedding),
+        }
+
+
+class EmbeddingDataset(BaseDataset):
+    """Item-embedding dataset read from a pickle holding ``item_id`` and ``embedding`` entries."""
+
+    def __init__(self, data_path):
+        self.data_path = data_path
+        # NOTE(review): pickle.load executes arbitrary code from the file; only use trusted inputs.
+        with open(data_path, 'rb') as handle:
+            self.data = pickle.load(handle)
+
+        payload = self.data
+        self.item_ids = np.array(payload['item_id'], dtype=np.int64)
+        self.embeddings = np.array(payload['embedding'], dtype=np.float32)
+
+    def __len__(self):
+        return len(self.embeddings)
+
+    def __getitem__(self, idx):
+        """Return the id, the embedding row and its length for position ``idx``."""
+        embedding = self.embeddings[idx]
+        return {
+            'item_id': self.item_ids[idx],
+            'embedding': embedding,
+            'embedding_dim': len(embedding),
+        }
+
+
+class ProcessEmbeddings(Transform):
+    """Transform that reshapes the selected batch entries to ``(-1, embedding_dim)``."""
+
+    def __init__(self, embedding_dim, keys):
+        self.embedding_dim = embedding_dim
+        self.keys = keys
+
+    def __call__(self, batch):
+        """Reshape ``batch[key]`` in place for every configured key and return the batch."""
+        target_shape = (-1, self.embedding_dim)
+        for name in self.keys:
+            batch[name] = batch[name].reshape(*target_shape)
+        return batch
\ No newline at end of file
diff --git a/scripts/plum-yambda/models.py b/scripts/plum-yambda/models.py
new file mode 100644
index 0000000..a411519
--- /dev/null
+++ b/scripts/plum-yambda/models.py
@@ -0,0 +1,135 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class PlumRQVAE(nn.Module):
+    """Residual-quantised autoencoder with an optional contrastive co-occurrence loss.
+
+    The encoder maps ``input_dim`` -> ``embedding_dim``; the latent vector is
+    quantised by ``num_codebooks`` codebooks applied to successive residuals,
+    and the decoder maps the sum of the selected codebook vectors back to
+    ``input_dim``.
+
+    Args:
+        input_dim: dimensionality of the input embeddings.
+        num_codebooks: number of residual quantisation levels.
+        codebook_size: number of vectors per codebook.
+        embedding_dim: dimensionality of the latent space / codebook vectors.
+        beta: weight of the commitment term of the quantisation loss.
+        quant_loss_weight: weight of the quantisation loss in the total loss.
+        contrastive_loss_weight: weight of the contrastive loss in the total loss.
+        temperature: softmax temperature of the contrastive loss. It must be
+            positive whenever ``'cooccurrence_embedding'`` is passed to
+            ``forward``; the default of 0.0 is only usable without it.
+    """
+
+    def __init__(
+        self,
+        input_dim,
+        num_codebooks,
+        codebook_size,
+        embedding_dim,
+        beta=0.25,
+        quant_loss_weight=1.0,
+        contrastive_loss_weight=1.0,
+        temperature=0.0,
+    ):
+        super().__init__()
+        self.register_buffer('beta', torch.tensor(beta))
+        self.temperature = temperature
+
+        self.input_dim = input_dim
+        self.num_codebooks = num_codebooks
+        self.codebook_size = codebook_size
+        self.embedding_dim = embedding_dim
+        self.quant_loss_weight = quant_loss_weight
+
+        self.contrastive_loss_weight = contrastive_loss_weight
+
+        self.encoder = self.make_encoding_tower(input_dim, embedding_dim)
+        self.decoder = self.make_encoding_tower(embedding_dim, input_dim)
+
+        self.codebooks = torch.nn.ParameterList()
+        for _ in range(num_codebooks):
+            # The storage is deliberately left uninitialised (no RNG is consumed);
+            # the values must be filled in before training, e.g. by an initialisation
+            # callback. Wrapping in nn.Parameter makes the registration explicit
+            # instead of relying on ParameterList to convert a raw tensor.
+            codebook = nn.Parameter(torch.empty(codebook_size, embedding_dim, dtype=torch.float32))
+            self.codebooks.append(codebook)
+
+    @staticmethod
+    def make_encoding_tower(d1, d2, bias=False):
+        """Three-layer MLP ``d1 -> d1 -> d2 -> d2``; ``bias`` only affects the last layer."""
+        return torch.nn.Sequential(
+            nn.Linear(d1, d1),
+            nn.ReLU(),
+            nn.Linear(d1, d2),
+            nn.ReLU(),
+            nn.Linear(d2, d2, bias=bias)
+        )
+
+    @staticmethod
+    def get_codebook_indices(remainder, codebook):
+        """Index of the nearest (Euclidean) codebook vector for every row of ``remainder``."""
+        dist = torch.cdist(remainder, codebook)
+        return dist.argmin(dim=-1)
+
+    def _quantize_representation(self, latent_vector):
+        """Residual-quantise ``latent_vector`` and return the sum of the selected vectors.
+
+        The straight-through form ``r + (q - r).detach()`` is used, so the result
+        equals the quantised value while gradients flow to the latent vector.
+        """
+        latent_restored = 0
+        remainder = latent_vector
+
+        for codebook in self.codebooks:
+            codebook_indices = self.get_codebook_indices(remainder, codebook)
+            quantized = codebook[codebook_indices]
+            codebook_vectors = remainder + (quantized - remainder).detach()
+            latent_restored += codebook_vectors
+            remainder = remainder - codebook_vectors
+
+        return latent_restored
+
+    def contrastive_loss(self, p_i, p_i_star):
+        """In-batch contrastive loss between row-aligned representations.
+
+        Both inputs are L2-normalised; row ``i`` of ``p_i`` is the positive of
+        row ``i`` of ``p_i_star`` and every other row acts as a negative.
+
+        Raises:
+            ValueError: if ``self.temperature`` is not positive.
+        """
+        if self.temperature <= 0:
+            # Dividing by a zero temperature silently yields inf/NaN similarities.
+            raise ValueError(
+                f'temperature must be positive to compute the contrastive loss, got {self.temperature}'
+            )
+
+        N_b = p_i.size(0)
+
+        p_i = F.normalize(p_i, p=2, dim=-1)  # TODO: try without normalisation
+        p_i_star = F.normalize(p_i_star, p=2, dim=-1)
+
+        similarities = torch.matmul(p_i, p_i_star.T) / self.temperature
+
+        labels = torch.arange(N_b, dtype=torch.long, device=p_i.device)
+
+        loss = F.cross_entropy(similarities, labels)
+
+        return loss  # only over the last dimension
+
+    def forward(self, inputs):
+        """Compute the training loss and diagnostics for one batch.
+
+        Args:
+            inputs: dict with ``'embedding'`` (batch of input embeddings) and,
+                optionally, ``'cooccurrence_embedding'`` (row-aligned positives
+                for the contrastive loss).
+
+        Returns:
+            ``(loss, outputs)`` where ``outputs`` holds the scalar losses as
+            Python floats, per-level cluster counts, the cluster indices as a
+            ``(batch, num_codebooks)`` tensor and the reconstructed embeddings.
+        """
+        latent_vector = self.encoder(inputs['embedding'])
+
+        latent_restored = 0
+        rqvae_loss = 0
+        clusters = []
+        remainder = latent_vector
+
+        for codebook in self.codebooks:
+            codebook_indices = self.get_codebook_indices(remainder, codebook)
+            clusters.append(codebook_indices)
+
+            quantized = codebook[codebook_indices]
+            # Straight-through estimator: value of `quantized`, gradient of `remainder`.
+            codebook_vectors = remainder + (quantized - remainder).detach()
+
+            # Commitment term (pulls the encoder output towards the codebook) ...
+            rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach())
+            # ... and codebook term (pulls the codebook towards the encoder output).
+            rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach())
+
+            latent_restored += codebook_vectors
+            remainder = remainder - codebook_vectors
+
+        embeddings_restored = self.decoder(latent_restored)
+        recon_loss = torch.nn.functional.mse_loss(embeddings_restored, inputs['embedding'])
+
+        if 'cooccurrence_embedding' in inputs:
+            cooccurrence_latent = self.encoder(inputs['cooccurrence_embedding'].to(latent_restored.device))
+            cooccurrence_restored = self._quantize_representation(cooccurrence_latent)
+            con_loss = self.contrastive_loss(latent_restored, cooccurrence_restored)
+        else:
+            con_loss = torch.as_tensor(0.0, device=latent_vector.device)
+
+        loss = (
+            recon_loss
+            + self.quant_loss_weight * rqvae_loss
+            + self.contrastive_loss_weight * con_loss
+        ).mean()
+
+        clusters_counts = []
+        for cluster in clusters:
+            clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size))
+
+        return loss, {
+            'loss': loss.item(),
+            'recon_loss': recon_loss.mean().item(),
+            'rqvae_loss': rqvae_loss.mean().item(),
+            'con_loss': con_loss.item(),
+
+            'clusters_counts': clusters_counts,
+            'clusters': torch.stack(clusters).T,
+            'embedding_hat': embeddings_restored,
+        }
\ No newline at end of file
diff --git a/scripts/plum-yambda/transforms.py b/scripts/plum-yambda/transforms.py
new file mode 100644
index 0000000..bdbfffa
--- /dev/null
+++ b/scripts/plum-yambda/transforms.py
@@ -0,0 +1,247 @@
+import numpy as np
+import pickle
+import torch
+import torch.nn.functional as F
+from typing import Dict, List
+from irec.data.base import BaseDataset
+from irec.data.transforms import Transform
+
+from cooc_data import CoocMappingDataset
+
+
+class AddWeightedCooccurrenceEmbeddings:
+    """Batch transform that attaches the embedding of one sampled co-occurring item per item.
+
+    The neighbour is drawn with probability proportional to its co-occurrence
+    count; the probabilities are recomputed on every call. Items without
+    co-occurrence data get a uniformly random id from ``all_item_ids``.
+    """
+
+    def __init__(self, cooccur_counts, item_id_to_embedding, all_item_ids):
+        self.cooccur_counts = cooccur_counts
+        self.item_id_to_embedding = item_id_to_embedding
+        self.all_item_ids = all_item_ids
+        self.call_count = 0
+
+    def __call__(self, batch):
+        """Add ``batch['cooccurrence_embedding']`` and return the batch."""
+        self.call_count += 1
+        sampled_embeddings = []
+
+        for idx, raw_id in enumerate(batch['item_id']):
+            if torch.is_tensor(raw_id):
+                item_id_val = int(raw_id.item())
+            else:
+                item_id_val = int(raw_id)
+
+            neighbour_counts = self.cooccur_counts.get(item_id_val)
+            if not neighbour_counts:
+                # No co-occurrence data: pick any item uniformly at random.
+                cooc_id = np.random.choice(self.all_item_ids)
+                if self.call_count % 500 == 0 and idx < 5:
+                    print(f" idx={idx}: item_id={item_id_val} fallback random")
+            else:
+                candidate_ids = tuple(neighbour_counts.keys())
+                weights = np.array(tuple(neighbour_counts.values()), dtype=np.float32)
+                cooc_id = np.random.choice(candidate_ids, p=weights / weights.sum())
+
+            # Row 0 of the batch embeddings is the default when the sampled id has no embedding.
+            default_embedding = batch['embedding'][0]
+            sampled_embeddings.append(self.item_id_to_embedding.get(cooc_id, default_embedding))
+
+        batch['cooccurrence_embedding'] = torch.stack(sampled_embeddings)
+        return batch
+
+
+
+class AddWeightedCooccurrenceEmbeddingsCached:
+    """Same sampling as the uncached transform, with per-item probabilities computed once.
+
+    ``cooc_probs_cache`` maps item_id -> ``(neighbour_ids, probabilities)`` for
+    every item that has at least one co-occurring neighbour.
+    """
+
+    def __init__(self, cooccur_counts, item_id_to_embedding, all_item_ids):
+        self.cooccur_counts = cooccur_counts
+        self.item_id_to_embedding = item_id_to_embedding
+        self.all_item_ids = all_item_ids
+        self.call_count = 0
+
+        self.cooc_probs_cache = {}
+        self._precompute_probabilities()
+
+    def _precompute_probabilities(self):
+        """Fill ``cooc_probs_cache`` with count-proportional sampling probabilities."""
+        for item_id, neighbour_counts in self.cooccur_counts.items():
+            if not neighbour_counts:
+                continue
+            candidate_ids = tuple(neighbour_counts.keys())
+            weights = np.array(tuple(neighbour_counts.values()), dtype=np.float32)
+            self.cooc_probs_cache[item_id] = (candidate_ids, weights / weights.sum())
+
+    def __call__(self, batch):
+        """Add ``batch['cooccurrence_embedding']`` and return the batch."""
+        self.call_count += 1
+        sampled_embeddings = []
+
+        for idx, raw_id in enumerate(batch['item_id']):
+            if torch.is_tensor(raw_id):
+                item_id_val = int(raw_id.item())
+            else:
+                item_id_val = int(raw_id)
+
+            if item_id_val not in self.cooc_probs_cache:
+                # No co-occurrence data: pick any item uniformly at random.
+                cooc_id = np.random.choice(self.all_item_ids)
+                if self.call_count % 10 == 0 and idx < 5:
+                    print(f" idx={idx}: item_id={item_id_val} fallback random")
+            else:
+                candidate_ids, probabilities = self.cooc_probs_cache[item_id_val]
+                cooc_id = np.random.choice(candidate_ids, p=probabilities)
+
+            # Row 0 of the batch embeddings is the default when the sampled id has no embedding.
+            default_embedding = batch['embedding'][0]
+            sampled_embeddings.append(self.item_id_to_embedding.get(cooc_id, default_embedding))
+
+        batch['cooccurrence_embedding'] = torch.stack(sampled_embeddings)
+        return batch
+
+class AddWeightedCooccurrenceEmbeddingsVectorized:
+    """Tensor-based transform that samples one co-occurring item embedding per batch item.
+
+    All tables are indexed directly by item id and have ``max_item_id + 1``
+    rows, where ``max_item_id`` is the largest key of ``item_id_to_embedding``.
+    Ids without an embedding keep an all-zero embedding row.
+    """
+
+    def __init__(
+        self,
+        cooccur_counts: Dict[int, Dict[int, int]],
+        item_id_to_embedding: Dict[int, torch.Tensor],
+        all_item_ids: List[int],
+        device: torch.device,
+        limit_neighbors: bool = True,
+        max_neighbors: int = 256
+    ):
+        """
+        Args:
+            cooccur_counts: item_id -> {neighbour_id: count}.
+            item_id_to_embedding: item_id -> 1-D embedding (tensor or array-like).
+            all_item_ids: ids used for the uniform fallback of items without neighbours.
+            device: device on which all tables are stored.
+            limit_neighbors: if True, keep at most ``max_neighbors`` neighbours
+                per item (saves memory).
+            max_neighbors: neighbour cap (only used when ``limit_neighbors`` is True).
+        """
+        self.device = device
+        self.call_count = 0
+        self.limit_neighbors = limit_neighbors
+        self.max_neighbors = max_neighbors
+
+        max_item_id = max(item_id_to_embedding.keys())
+        embedding_dim = next(iter(item_id_to_embedding.values())).shape[0]
+
+        self.embedding_matrix = torch.zeros(
+            (max_item_id + 1, embedding_dim),
+            device=device,
+            dtype=torch.float32,
+            requires_grad=False
+        )
+
+        print("Building embedding matrix")
+        for item_id, emb in item_id_to_embedding.items():
+            if isinstance(emb, torch.Tensor):
+                self.embedding_matrix[item_id] = emb.detach()
+            else:
+                self.embedding_matrix[item_id] = torch.tensor(emb, device=device, dtype=torch.float32)
+
+        self.all_item_ids_tensor = torch.tensor(
+            all_item_ids,
+            device=device,
+            dtype=torch.long,
+            requires_grad=False
+        )
+
+        print("Building cooccurrence tables")
+        self._build_cooccurrence_tables(cooccur_counts, max_item_id, len(all_item_ids))
+
+    def _build_cooccurrence_tables(self, cooccur_counts: Dict, max_item_id: int, num_all_items: int):
+        """Build the neighbour and probability tables.
+
+        - neighbors_matrix: [max_item_id+1, num_neighbors]
+        - probs_matrix: [max_item_id+1, num_neighbors]
+
+        Neighbour ids outside ``[0, max_item_id]`` have no row in the embedding
+        matrix and are dropped. Items left without neighbours get a uniform
+        distribution over ``all_item_ids`` (or over a random subset of it).
+        """
+        # Rank the usable neighbours of every item once, most frequent first.
+        ranked_neighbors = {}
+        for item_id in range(max_item_id + 1):
+            counter = cooccur_counts.get(item_id)
+            if not counter:
+                continue
+            pairs = [
+                (neighbor_id, count)
+                for neighbor_id, count in counter.items()
+                if 0 <= neighbor_id <= max_item_id
+            ]
+            if pairs:
+                pairs.sort(key=lambda pair: pair[1], reverse=True)
+                ranked_neighbors[item_id] = pairs
+
+        # Table width: the widest row needed, optionally capped by max_neighbors.
+        max_num_neighbors = 0
+        for item_id in range(max_item_id + 1):
+            pairs = ranked_neighbors.get(item_id)
+            if pairs:
+                num_neighbors = len(pairs)
+                if self.limit_neighbors:
+                    num_neighbors = min(num_neighbors, self.max_neighbors)
+            else:
+                num_neighbors = num_all_items
+            max_num_neighbors = max(max_num_neighbors, num_neighbors)
+
+        actual_max_neighbors = min(max_num_neighbors, self.max_neighbors) if self.limit_neighbors else max_num_neighbors
+
+        print(f"Max neighbors per item: {actual_max_neighbors}")
+
+        neighbors_matrix = torch.zeros(
+            (max_item_id + 1, actual_max_neighbors),
+            dtype=torch.long,
+            device=self.device,
+            requires_grad=False
+        )
+
+        probs_matrix = torch.zeros(
+            (max_item_id + 1, actual_max_neighbors),
+            dtype=torch.float32,
+            device=self.device,
+            requires_grad=False
+        )
+
+        num_items_with_cooc = 0
+
+        # Fill the tables row by row.
+        for item_id in range(max_item_id + 1):
+            pairs = ranked_neighbors.get(item_id)
+
+            if pairs:
+                # Real neighbours: count-proportional probabilities over the top entries.
+                num_items_with_cooc += 1
+
+                num_neighbors = min(len(pairs), actual_max_neighbors)
+                cooc_ids = [neighbor_id for neighbor_id, _ in pairs[:num_neighbors]]
+                freqs = np.array([count for _, count in pairs[:num_neighbors]], dtype=np.float32)
+
+                # Renormalise over the neighbours that survived truncation.
+                probs = freqs / freqs.sum()
+
+                neighbors_matrix[item_id, :num_neighbors] = torch.tensor(
+                    cooc_ids, dtype=torch.long, device=self.device
+                )
+                probs_matrix[item_id, :num_neighbors] = torch.tensor(
+                    probs, dtype=torch.float32, device=self.device
+                )
+
+            else:
+                # No neighbours: uniform distribution over all_item_ids.
+                if actual_max_neighbors >= num_all_items:
+                    # The row is wide enough to hold every item.
+                    neighbors_matrix[item_id, :num_all_items] = self.all_item_ids_tensor
+                    probs_matrix[item_id, :num_all_items] = 1.0 / num_all_items
+                else:
+                    # Otherwise use a random subset that fits the row.
+                    indices = torch.randperm(num_all_items, device=self.device)[:actual_max_neighbors]
+                    neighbors_matrix[item_id] = self.all_item_ids_tensor[indices]
+                    probs_matrix[item_id] = 1.0 / actual_max_neighbors
+
+        self.neighbors_matrix = neighbors_matrix
+        self.probs_matrix = probs_matrix
+
+        print(f"Cooccurrence tables built: {num_items_with_cooc}/{max_item_id + 1} items have real neighbors")
+
+    def __call__(self, batch):
+        """Add ``batch['cooccurrence_embedding']`` of shape [batch_size, embedding_dim]."""
+        self.call_count += 1
+
+        # Accept lists or tensors on any device; the tables are indexed on self.device.
+        item_ids = torch.as_tensor(batch['item_id'], device=self.device, dtype=torch.long)  # [batch_size]
+
+        # Sampling probabilities of the batch items: [batch_size, max_neighbors]
+        probs = self.probs_matrix[item_ids]
+
+        # torch.multinomial draws one column index per row according to the row's
+        # probabilities; result [batch_size, 1] with values in [0, max_neighbors).
+        neighbor_indices = torch.multinomial(probs, num_samples=1, replacement=True)
+        neighbor_indices = neighbor_indices.squeeze(1)  # [batch_size]
+
+        # neighbors_matrix[item_ids, neighbor_indices] -> [batch_size]
+        cooc_ids = self.neighbors_matrix[item_ids, neighbor_indices]
+
+        # Embedding lookup: [batch_size, embedding_dim]
+        cooccurrence_embeddings = self.embedding_matrix[cooc_ids]
+
+        batch['cooccurrence_embedding'] = cooccurrence_embeddings
+
+        return batch
\ No newline at end of file
diff --git a/scripts/plum-yambda/yambda_4_1_train_plum.py b/scripts/plum-yambda/yambda_4_1_train_plum.py
new file mode 100644
index 0000000..86e0c2b
--- /dev/null
+++ b/scripts/plum-yambda/yambda_4_1_train_plum.py
@@ -0,0 +1,186 @@
+from loguru import logger
+import os
+
+import torch
+
+import pickle
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from models import PlumRQVAE
+from transforms import AddWeightedCooccurrenceEmbeddingsVectorized
+from cooc_data import CoocMappingDataset
+
+# Seed passed to fix_random_seed() and the compute device (CUDA when available, else CPU).
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+# Training schedule: the train loader is repeated NUM_EPOCHS times in main().
+NUM_EPOCHS = 35
+BATCH_SIZE = 1024
+
+# Forwarded as max_neighbors to the co-occurrence transform in main().
+MAX_NEIGHBOURS_COUNT = 1000
+
+# Model / optimiser hyper-parameters forwarded to PlumRQVAE and Adam in main();
+# WINDOW_SIZE is forwarded to CoocMappingDataset.create_from_split_part().
+INPUT_DIM = 128
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+WINDOW_SIZE = 2
+
+# Run name and data locations. IREC_PATH is relative, so the tensorboard and
+# checkpoint directories built from it depend on the current working directory.
+EXPERIMENT_NAME = f'4-1_filtered_yambda_gpu_week_ws_{WINDOW_SIZE}'
+INTER_TRAIN_PATH = "/home/jovyan/IRec/data/Yambda/week-splits/merged_for_exps_filtered/exp_4-1_0.9_inter_semantics_train.json" # cut off the old data (or maybe not)
+EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/yambda_data/yambda_embeddings_reindexed.parquet"
+IREC_PATH = '../../'
+
+# Echoes the interactions file in use; runs at import time, not only under __main__.
+print(INTER_TRAIN_PATH)
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ data = CoocMappingDataset.create_from_split_part(
+ train_inter_json_path=INTER_TRAIN_PATH,
+ window_size=WINDOW_SIZE
+ )
+
+ dataset = EmbeddingDatasetParquet(
+ data_path=EMBEDDINGS_PATH
+ )
+
+ item_id_to_embedding = {}
+ all_item_ids = []
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ item_id = int(sample['item_id'])
+ item_id_to_embedding[item_id] = torch.tensor(sample['embedding'], device=DEVICE)
+ all_item_ids.append(item_id)
+
+ add_cooc_transform = AddWeightedCooccurrenceEmbeddingsVectorized(
+ cooccur_counts=data.cooccur_counter_mapping,
+ item_id_to_embedding=item_id_to_embedding,
+ all_item_ids=all_item_ids,
+ device=DEVICE,
+ limit_neighbors=True,
+ max_neighbors = MAX_NEIGHBOURS_COUNT
+ )
+
+ train_dataloader = DataLoader( #call в основном потоке делается нужно исправить
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=True,
+ drop_last=True,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
+ ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+ ).map(add_cooc_transform
+ ).repeat(NUM_EPOCHS)
+
+ valid_dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=False,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+ ).map(add_cooc_transform)
+
+ LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
+
+ model = PlumRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=1.0,
+ contrastive_loss_weight=1.0,
+ temperature=1.0
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True)
+
+ callbacks = [
+ InitCodebooks(valid_dataloader),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='train'),
+
+ FixDeadCentroids(valid_dataloader),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'train/loss': cb.MeanAccumulator(),
+ 'train/recon_loss': cb.MeanAccumulator(),
+ 'train/rqvae_loss': cb.MeanAccumulator(),
+ 'train/con_loss': cb.MeanAccumulator(),
+ 'num_dead/0': cb.MeanAccumulator(),
+ 'num_dead/1': cb.MeanAccumulator(),
+ 'num_dead/2': cb.MeanAccumulator(),
+ },
+ reset_every_num_steps=LOG_EVERY_NUM_STEPS
+ ),
+
+ cb.Validation(
+ dataset=valid_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='valid'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/con_loss': cb.MeanAccumulator()
+ }
+ ),
+ ],
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+
+ cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+
+ cb.Profiler(
+ wait=10,
+ warmup=10,
+ active=10,
+ logdir=os.path.join(IREC_PATH, 'tensorboard_logs')
+ ),
+
+ cb.EarlyStopping(
+ metric='valid/recon_loss',
+ patience=40,
+ minimize=True,
+ model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = TrainingRunner(
+ model=model,
+ optimizer=optimizer,
+ dataset=train_dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plum-yambda/yambda_infer_4.1_default.py b/scripts/plum-yambda/yambda_infer_4.1_default.py
new file mode 100644
index 0000000..1485fde
--- /dev/null
+++ b/scripts/plum-yambda/yambda_infer_4.1_default.py
@@ -0,0 +1,145 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from models import PlumRQVAE
+
+# PATHS
+# Base directory, item-embedding table, checkpoint to load, and the directory
+# that receives the inference results (all absolute, machine-specific paths).
+IREC_PATH = '/home/jovyan/IRec/'
+EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/yambda_data/yambda_embeddings_reindexed.parquet"
+MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-1_filtered_yambda_gpu_quantile_ws_2_best_0.0026.pth'
+RESULTS_PATH = os.path.join(IREC_PATH, 'results_sigir_yambda')
+
+# WINDOW_SIZE is only used here to build EXPERIMENT_NAME, which prefixes the
+# output file names written by main().
+WINDOW_SIZE = 2
+EXPERIMENT_NAME = f'4-1_filtered_yambda_gpu_quantile_ws_{WINDOW_SIZE}'
+
+# OTHER SETTINGS
+
+# Seed passed to fix_random_seed() and the compute device (CUDA when available, else CPU).
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+BATCH_SIZE = 1024
+
+# Model shape forwarded to PlumRQVAE in main(); presumably must match the
+# values the checkpoint at MODEL_PATH was trained with -- verify.
+INPUT_DIM = 128
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+
+BETA = 0.25
+
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ dataset = EmbeddingDatasetParquet(
+ data_path=EMBEDDINGS_PATH
+ )
+
+ item_id_to_embedding = {}
+ all_item_ids = []
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ item_id = int(sample['item_id'])
+ item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
+ all_item_ids.append(item_id)
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=False,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))
+
+ model = PlumRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=1.0,
+ contrastive_loss_weight=1.0,
+ temperature=1.0
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ callbacks = [
+ cb.LoadModel(MODEL_PATH),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='valid'),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/con_loss': cb.MeanAccumulator(),
+ },
+ ),
+
+ cb.Logger().every_num_steps(len(dataloader)),
+
+ cb.InferenceSaver(
+ metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
+ save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'),
+ format='json'
+ )
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = InferenceRunner(
+ model=model,
+ dataset=dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+ import json
+ from collections import defaultdict
+ import numpy as np
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f:
+ mappings = json.load(f)
+
+ inter = {}
+ sem_2_ids = defaultdict(list)
+ for mapping in mappings:
+ item_id = mapping['item_id']
+ clusters = mapping['clusters']
+ inter[int(item_id)] = clusters
+ sem_2_ids[tuple(clusters)].append(int(item_id))
+
+ for semantics, items in sem_2_ids.items():
+ assert len(items) <= CODEBOOK_SIZE, str(len(items))
+ collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
+ for item_id, collision_solver in zip(items, collision_solvers):
+ inter[item_id].append(collision_solver)
+ for i in range(len(inter[item_id])):
+ inter[item_id][i] += CODEBOOK_SIZE * i
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
+ json.dump(inter, f, indent=2)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plum/beauty-exps/4_1_train_plum.py b/scripts/plum/beauty-exps/4_1_train_plum.py
new file mode 100644
index 0000000..357fc19
--- /dev/null
+++ b/scripts/plum/beauty-exps/4_1_train_plum.py
@@ -0,0 +1,169 @@
+from loguru import logger
+import os
+
+import torch
+
+import pickle
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDataset, ProcessEmbeddings
+from models import PlumRQVAE
+from transforms import AddWeightedCooccurrenceEmbeddings
+from cooc_data import CoocMappingDataset
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+NUM_EPOCHS = 500
+BATCH_SIZE = 1024
+
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+WINDOW_SIZE = 2
+
+EXPERIMENT_NAME = f'4-1_yambda_quantile_ws_{WINDOW_SIZE}'
+INTER_TRAIN_PATH = "/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-1_0.9_inter_semantics_train.json"
+EMBEDDINGS_PATH = "/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl"
+IREC_PATH = '../../'
+
+print(INTER_TRAIN_PATH)
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ data = CoocMappingDataset.create_from_split_part(
+ train_inter_json_path=INTER_TRAIN_PATH,
+ window_size=WINDOW_SIZE
+ )
+
+ dataset = EmbeddingDataset(
+ data_path=EMBEDDINGS_PATH
+ )
+
+ item_id_to_embedding = {}
+ all_item_ids = []
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ item_id = int(sample['item_id'])
+ item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
+ all_item_ids.append(item_id)
+
+ add_cooc_transform = AddWeightedCooccurrenceEmbeddings(
+ data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids)
+
+ train_dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=True,
+ drop_last=True,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
+ ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+ ).map(add_cooc_transform).repeat(NUM_EPOCHS)
+
+ valid_dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=False,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform)
+
+ LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
+
+ model = PlumRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=1.0,
+ contrastive_loss_weight=1.0,
+ temperature=1.0
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True)
+
+ callbacks = [
+ InitCodebooks(valid_dataloader),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='train'),
+
+ FixDeadCentroids(valid_dataloader),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'train/loss': cb.MeanAccumulator(),
+ 'train/recon_loss': cb.MeanAccumulator(),
+ 'train/rqvae_loss': cb.MeanAccumulator(),
+ 'train/con_loss': cb.MeanAccumulator(),
+ 'num_dead/0': cb.MeanAccumulator(),
+ 'num_dead/1': cb.MeanAccumulator(),
+ 'num_dead/2': cb.MeanAccumulator(),
+ },
+ reset_every_num_steps=LOG_EVERY_NUM_STEPS
+ ),
+
+ cb.Validation(
+ dataset=valid_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='valid'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/con_loss': cb.MeanAccumulator()
+ }
+ ),
+ ],
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+
+ cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+
+ cb.EarlyStopping(
+ metric='valid/recon_loss',
+ patience=40,
+ minimize=True,
+ model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = TrainingRunner(
+ model=model,
+ optimizer=optimizer,
+ dataset=train_dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plum/beauty-exps/4_2_train_plum.py b/scripts/plum/beauty-exps/4_2_train_plum.py
new file mode 100644
index 0000000..96cfda9
--- /dev/null
+++ b/scripts/plum/beauty-exps/4_2_train_plum.py
@@ -0,0 +1,169 @@
+from loguru import logger
+import os
+
+import torch
+
+import pickle
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDataset, ProcessEmbeddings
+from models import PlumRQVAE
+from transforms import AddWeightedCooccurrenceEmbeddings
+from cooc_data import CoocMappingDataset
+
+# Seed passed to fix_random_seed() and the compute device (CUDA when available, else CPU).
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+# Training schedule: the train loader is repeated NUM_EPOCHS times in main().
+NUM_EPOCHS = 500
+BATCH_SIZE = 1024
+
+# Model / optimiser hyper-parameters forwarded to PlumRQVAE and Adam in main();
+# WINDOW_SIZE is forwarded to CoocMappingDataset.create_from_split_part().
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+WINDOW_SIZE = 2
+
+# Run name and data locations. IREC_PATH is relative, so the tensorboard and
+# checkpoint directories built from it depend on the current working directory.
+EXPERIMENT_NAME = f'4-2_updated_quantile_plum_rqvae_beauty_ws_{WINDOW_SIZE}'
+INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-2_0.8_inter_semantics_train.json"
+EMBEDDINGS_PATH = "/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl"
+IREC_PATH = '../../'
+
+# Echoes the interactions file in use; runs at import time, not only under __main__.
+print(INTER_TRAIN_PATH)
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ data = CoocMappingDataset.create_from_split_part(
+ train_inter_json_path=INTER_TRAIN_PATH,
+ window_size=WINDOW_SIZE
+ )
+
+ dataset = EmbeddingDataset(
+ data_path=EMBEDDINGS_PATH
+ )
+
+ item_id_to_embedding = {}
+ all_item_ids = []
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ item_id = int(sample['item_id'])
+ item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
+ all_item_ids.append(item_id)
+
+ add_cooc_transform = AddWeightedCooccurrenceEmbeddings(
+ data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids)
+
+ train_dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=True,
+ drop_last=True,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
+ ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+ ).map(add_cooc_transform).repeat(NUM_EPOCHS)
+
+ valid_dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=False,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform)
+
+ LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
+
+ model = PlumRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=1.0,
+ contrastive_loss_weight=1.0,
+ temperature=1.0
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True)
+
+ callbacks = [
+ InitCodebooks(valid_dataloader),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='train'),
+
+ FixDeadCentroids(valid_dataloader),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'train/loss': cb.MeanAccumulator(),
+ 'train/recon_loss': cb.MeanAccumulator(),
+ 'train/rqvae_loss': cb.MeanAccumulator(),
+ 'train/con_loss': cb.MeanAccumulator(),
+ 'num_dead/0': cb.MeanAccumulator(),
+ 'num_dead/1': cb.MeanAccumulator(),
+ 'num_dead/2': cb.MeanAccumulator(),
+ },
+ reset_every_num_steps=LOG_EVERY_NUM_STEPS
+ ),
+
+ cb.Validation(
+ dataset=valid_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='valid'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/con_loss': cb.MeanAccumulator()
+ }
+ ),
+ ],
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+
+ cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+
+ cb.EarlyStopping(
+ metric='valid/recon_loss',
+ patience=40,
+ minimize=True,
+ model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = TrainingRunner(
+ model=model,
+ optimizer=optimizer,
+ dataset=train_dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plum/beauty-exps/4_3_train_plum.py b/scripts/plum/beauty-exps/4_3_train_plum.py
new file mode 100644
index 0000000..ac6cfb6
--- /dev/null
+++ b/scripts/plum/beauty-exps/4_3_train_plum.py
@@ -0,0 +1,169 @@
+from loguru import logger
+import os
+
+import torch
+
+import pickle
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDataset, ProcessEmbeddings
+from models import PlumRQVAE
+from transforms import AddWeightedCooccurrenceEmbeddings
+from cooc_data import CoocMappingDataset
+
+# Seed passed to fix_random_seed() and the compute device (CUDA when available, else CPU).
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+# Training schedule: the train loader is repeated NUM_EPOCHS times in main().
+NUM_EPOCHS = 500
+BATCH_SIZE = 1024
+
+# Model / optimiser hyper-parameters forwarded to PlumRQVAE and Adam in main();
+# WINDOW_SIZE is forwarded to CoocMappingDataset.create_from_split_part().
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+WINDOW_SIZE = 2
+
+# Run name and data locations. IREC_PATH is relative, so the tensorboard and
+# checkpoint directories built from it depend on the current working directory.
+EXPERIMENT_NAME = f'4-3_updated_quantile_plum_rqvae_beauty_ws_{WINDOW_SIZE}'
+INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-3_0.95_inter_semantics_train.json"
+EMBEDDINGS_PATH = "/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl"
+IREC_PATH = '../../'
+
+# Echoes the interactions file in use; runs at import time, not only under __main__.
+print(INTER_TRAIN_PATH)
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ data = CoocMappingDataset.create_from_split_part(
+ train_inter_json_path=INTER_TRAIN_PATH,
+ window_size=WINDOW_SIZE
+ )
+
+ dataset = EmbeddingDataset(
+ data_path=EMBEDDINGS_PATH
+ )
+
+ item_id_to_embedding = {}
+ all_item_ids = []
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ item_id = int(sample['item_id'])
+ item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
+ all_item_ids.append(item_id)
+
+ add_cooc_transform = AddWeightedCooccurrenceEmbeddings(
+ data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids)
+
+ train_dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=True,
+ drop_last=True,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
+ ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+ ).map(add_cooc_transform).repeat(NUM_EPOCHS)
+
+ valid_dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=False,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform)
+
+ LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
+
+ model = PlumRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=1.0,
+ contrastive_loss_weight=1.0,
+ temperature=1.0
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True)
+
+ callbacks = [
+ InitCodebooks(valid_dataloader),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='train'),
+
+ FixDeadCentroids(valid_dataloader),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'train/loss': cb.MeanAccumulator(),
+ 'train/recon_loss': cb.MeanAccumulator(),
+ 'train/rqvae_loss': cb.MeanAccumulator(),
+ 'train/con_loss': cb.MeanAccumulator(),
+ 'num_dead/0': cb.MeanAccumulator(),
+ 'num_dead/1': cb.MeanAccumulator(),
+ 'num_dead/2': cb.MeanAccumulator(),
+ },
+ reset_every_num_steps=LOG_EVERY_NUM_STEPS
+ ),
+
+ cb.Validation(
+ dataset=valid_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='valid'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/con_loss': cb.MeanAccumulator()
+ }
+ ),
+ ],
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+
+ cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+
+ cb.EarlyStopping(
+ metric='valid/recon_loss',
+ patience=40,
+ minimize=True,
+ model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = TrainingRunner(
+ model=model,
+ optimizer=optimizer,
+ dataset=train_dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plum/beauty-exps/infer_4.1_default.py b/scripts/plum/beauty-exps/infer_4.1_default.py
new file mode 100644
index 0000000..fff61d3
--- /dev/null
+++ b/scripts/plum/beauty-exps/infer_4.1_default.py
@@ -0,0 +1,145 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+
+from data import EmbeddingDataset, ProcessEmbeddings
+from models import PlumRQVAE
+
+# PATHS
+# Base directory, item-embedding file, checkpoint to load, and the directory
+# that receives the inference results (all absolute, machine-specific paths).
+IREC_PATH = '/home/jovyan/IRec/'
+EMBEDDINGS_PATH = '/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl'
+MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-1_updated_quantile_plum_rqvae_beauty_ws_2_best_0.0052.pth'
+RESULTS_PATH = os.path.join(IREC_PATH, 'results_sigir')
+
+# WINDOW_SIZE is only used here to build EXPERIMENT_NAME, which prefixes the
+# output file names written by main().
+WINDOW_SIZE = 2
+EXPERIMENT_NAME = f'4-1_updated_quantile_plum_rqvae_beauty_ws_{WINDOW_SIZE}'
+
+# OTHER SETTINGS
+
+# Seed passed to fix_random_seed() and the compute device (CUDA when available, else CPU).
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+BATCH_SIZE = 1024
+
+# Model shape forwarded to PlumRQVAE in main(); presumably must match the
+# values the checkpoint at MODEL_PATH was trained with -- verify.
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+
+BETA = 0.25
+
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ dataset = EmbeddingDataset(
+ data_path=EMBEDDINGS_PATH
+ )
+
+ item_id_to_embedding = {}
+ all_item_ids = []
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ item_id = int(sample['item_id'])
+ item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
+ all_item_ids.append(item_id)
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=False,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))
+
+ model = PlumRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=1.0,
+ contrastive_loss_weight=1.0,
+ temperature=1.0
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ callbacks = [
+ cb.LoadModel(MODEL_PATH),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='valid'),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/con_loss': cb.MeanAccumulator(),
+ },
+ ),
+
+ cb.Logger().every_num_steps(len(dataloader)),
+
+ cb.InferenceSaver(
+ metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
+ save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'),
+ format='json'
+ )
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = InferenceRunner(
+ model=model,
+ dataset=dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+ import json
+ from collections import defaultdict
+ import numpy as np
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f:
+ mappings = json.load(f)
+
+ inter = {}
+ sem_2_ids = defaultdict(list)
+ for mapping in mappings:
+ item_id = mapping['item_id']
+ clusters = mapping['clusters']
+ inter[int(item_id)] = clusters
+ sem_2_ids[tuple(clusters)].append(int(item_id))
+
+ for semantics, items in sem_2_ids.items():
+ assert len(items) <= CODEBOOK_SIZE, str(len(items))
+ collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
+ for item_id, collision_solver in zip(items, collision_solvers):
+ inter[item_id].append(collision_solver)
+ for i in range(len(inter[item_id])):
+ inter[item_id][i] += CODEBOOK_SIZE * i
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
+ json.dump(inter, f, indent=2)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plum/beauty-exps/infer_4.2_default.py b/scripts/plum/beauty-exps/infer_4.2_default.py
new file mode 100644
index 0000000..c5c7c02
--- /dev/null
+++ b/scripts/plum/beauty-exps/infer_4.2_default.py
@@ -0,0 +1,145 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+
+from data import EmbeddingDataset, ProcessEmbeddings
+from models import PlumRQVAE
+
+# PATHS
+# Base directory, item-embedding file, checkpoint to load, and the directory
+# that receives the inference results (all absolute, machine-specific paths).
+IREC_PATH = '/home/jovyan/IRec/'
+EMBEDDINGS_PATH = '/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl'
+MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-2_updated_quantile_plum_rqvae_beauty_ws_2_best_0.0051.pth'
+RESULTS_PATH = os.path.join(IREC_PATH, 'results_sigir')
+
+# WINDOW_SIZE is only used here to build EXPERIMENT_NAME, which prefixes the
+# output file names written by main().
+WINDOW_SIZE = 2
+EXPERIMENT_NAME = f'4-2_updated_quantile_plum_rqvae_beauty_ws_{WINDOW_SIZE}'
+
+# OTHER SETTINGS
+
+# Seed passed to fix_random_seed() and the compute device (CUDA when available, else CPU).
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+BATCH_SIZE = 1024
+
+# Model shape forwarded to PlumRQVAE in main(); presumably must match the
+# values the checkpoint at MODEL_PATH was trained with -- verify.
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+
+BETA = 0.25
+
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ dataset = EmbeddingDataset(
+ data_path=EMBEDDINGS_PATH
+ )
+
+ item_id_to_embedding = {}
+ all_item_ids = []
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ item_id = int(sample['item_id'])
+ item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
+ all_item_ids.append(item_id)
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=False,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))
+
+ model = PlumRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=1.0,
+ contrastive_loss_weight=1.0,
+ temperature=1.0
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ callbacks = [
+ cb.LoadModel(MODEL_PATH),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='valid'),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/con_loss': cb.MeanAccumulator(),
+ },
+ ),
+
+ cb.Logger().every_num_steps(len(dataloader)),
+
+ cb.InferenceSaver(
+ metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
+ save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'),
+ format='json'
+ )
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = InferenceRunner(
+ model=model,
+ dataset=dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+ import json
+ from collections import defaultdict
+ import numpy as np
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f:
+ mappings = json.load(f)
+
+ inter = {}
+ sem_2_ids = defaultdict(list)
+ for mapping in mappings:
+ item_id = mapping['item_id']
+ clusters = mapping['clusters']
+ inter[int(item_id)] = clusters
+ sem_2_ids[tuple(clusters)].append(int(item_id))
+
+ for semantics, items in sem_2_ids.items():
+ assert len(items) <= CODEBOOK_SIZE, str(len(items))
+ collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
+ for item_id, collision_solver in zip(items, collision_solvers):
+ inter[item_id].append(collision_solver)
+ for i in range(len(inter[item_id])):
+ inter[item_id][i] += CODEBOOK_SIZE * i
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
+ json.dump(inter, f, indent=2)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/plum/beauty-exps/infer_4.3_default.py b/scripts/plum/beauty-exps/infer_4.3_default.py
new file mode 100644
index 0000000..c7fca80
--- /dev/null
+++ b/scripts/plum/beauty-exps/infer_4.3_default.py
@@ -0,0 +1,145 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+
+from data import EmbeddingDataset, ProcessEmbeddings
+from models import PlumRQVAE
+
+# Paths (absolute, tied to the machine this experiment was run on).
+IREC_PATH = '/home/jovyan/IRec/'
+EMBEDDINGS_PATH = '/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl'
+# Checkpoint handed to cb.LoadModel in main().
+MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-3_updated_quantile_plum_rqvae_beauty_ws_2_best_0.005.pth'
+# Directory holding the JSON files that main() writes and reads back.
+RESULTS_PATH = os.path.join(IREC_PATH, 'results_sigir')
+
+# Within this script WINDOW_SIZE only feeds the experiment name below.
+WINDOW_SIZE = 2
+EXPERIMENT_NAME = f'4-3_updated_quantile_plum_rqvae_beauty_ws_{WINDOW_SIZE}'
+
+# Other settings
+
+SEED_VALUE = 42
+# Use CUDA when it is available, otherwise fall back to the CPU.
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+BATCH_SIZE = 1024
+
+# Arguments forwarded to the PlumRQVAE constructor in main().
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+# Also bounds the collision-resolving codes drawn in main().
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+
+BETA = 0.25
+
+
+
+def main():
+    """Assign semantic ids to every item with a PlumRQVAE and make them unique.
+
+    The run has two stages:
+
+    1. An ``InferenceRunner`` pass over ``EmbeddingDataset(EMBEDDINGS_PATH)``.
+       The ``InferenceSaver`` callback is configured with each item's
+       ``item_id`` and ``clusters`` and the path
+       ``{EXPERIMENT_NAME}_clusters.json`` inside ``RESULTS_PATH``.
+    2. Post-processing of that file: items sharing an identical cluster tuple
+       each receive one extra, distinct code, after which the code at
+       position ``i`` is shifted by ``CODEBOOK_SIZE * i``. The result is
+       written to ``{EXPERIMENT_NAME}_clusters_colisionless.json``.
+
+    Raises:
+        AssertionError: if more than ``CODEBOOK_SIZE`` items share one
+            cluster tuple, because no distinct extra code would be left.
+    """
+    fix_random_seed(SEED_VALUE)
+
+    dataset = EmbeddingDataset(
+        data_path=EMBEDDINGS_PATH
+    )
+
+    # No item_id -> embedding lookup is built here: this script has no
+    # co-occurrence transform, so nothing would consume it.
+    dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=False,
+        drop_last=False,
+    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
+        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+    )
+
+    model = PlumRQVAE(
+        input_dim=INPUT_DIM,
+        num_codebooks=NUM_CODEBOOKS,
+        codebook_size=CODEBOOK_SIZE,
+        embedding_dim=HIDDEN_DIM,
+        beta=BETA,
+        quant_loss_weight=1.0,
+        contrastive_loss_weight=1.0,
+        temperature=1.0
+    ).to(DEVICE)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.debug(f'Overall parameters: {total_params:,}')
+    logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+    clusters_path = os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json')
+    collisionless_path = os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json')
+
+    callbacks = [
+        cb.LoadModel(MODEL_PATH),
+
+        cb.BatchMetrics(metrics=lambda model_outputs, _: {
+            'loss': model_outputs['loss'],
+            'recon_loss': model_outputs['recon_loss'],
+            'rqvae_loss': model_outputs['rqvae_loss'],
+            'con_loss': model_outputs['con_loss']
+        }, name='valid'),
+
+        cb.MetricAccumulator(
+            accumulators={
+                'valid/loss': cb.MeanAccumulator(),
+                'valid/recon_loss': cb.MeanAccumulator(),
+                'valid/rqvae_loss': cb.MeanAccumulator(),
+                'valid/con_loss': cb.MeanAccumulator(),
+            },
+        ),
+
+        cb.Logger().every_num_steps(len(dataloader)),
+
+        cb.InferenceSaver(
+            metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
+            save_path=clusters_path,
+            format='json'
+        )
+    ]
+
+    logger.debug('Everything is ready for inference process!')
+
+    runner = InferenceRunner(
+        model=model,
+        dataset=dataloader,
+        callbacks=callbacks,
+    )
+    runner.run()
+
+    import json
+    from collections import defaultdict
+    import numpy as np
+
+    with open(clusters_path, 'r') as f:
+        mappings = json.load(f)
+
+    # inter: item id -> list of codes; sem_2_ids: cluster tuple -> item ids sharing it.
+    inter = {}
+    sem_2_ids = defaultdict(list)
+    for mapping in mappings:
+        item_id = int(mapping['item_id'])
+        clusters = mapping['clusters']
+        inter[item_id] = clusters
+        sem_2_ids[tuple(clusters)].append(item_id)
+
+    for items in sem_2_ids.values():
+        assert len(items) <= CODEBOOK_SIZE, str(len(items))
+        # Distinct extra codes for the items of one group, drawn without replacement.
+        collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
+        for item_id, collision_solver in zip(items, collision_solvers):
+            inter[item_id].append(collision_solver)
+            # Shift position i by CODEBOOK_SIZE * i so every position has its own id range.
+            for i in range(len(inter[item_id])):
+                inter[item_id][i] += CODEBOOK_SIZE * i
+
+    # json.dump serialises the integer item ids as string keys.
+    with open(collisionless_path, 'w') as f:
+        json.dump(inter, f, indent=2)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/plum/beauty-exps/infer_default.py b/scripts/plum/beauty-exps/infer_default.py
new file mode 100644
index 0000000..af8df34
--- /dev/null
+++ b/scripts/plum/beauty-exps/infer_default.py
@@ -0,0 +1,152 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+
+from data import EmbeddingDataset, ProcessEmbeddings
+from models import PlumRQVAE
+from transforms import AddWeightedCooccurrenceEmbeddings
+from cooc_data import CoocMappingDataset
+
+# Seed passed to fix_random_seed() at the start of main().
+SEED_VALUE = 42
+# Use CUDA when it is available, otherwise fall back to the CPU.
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+BATCH_SIZE = 1024
+
+# Arguments forwarded to the PlumRQVAE constructor in main().
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+# Also bounds the collision-resolving codes drawn in main().
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+
+BETA = 0.25
+# Checkpoint handed to cb.LoadModel in main().
+MODEL_PATH = '/home/jovyan/IRec/checkpoints/test_plum_rqvae_beauty_ws_2_best_0.0054.pth'
+
+# Window size used for the co-occurrence mapping and embedded in the experiment name.
+WINDOW_SIZE = 2
+
+EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}'
+
+# Absolute repository root; main() joins data and result paths onto it.
+IREC_PATH = '/home/jovyan/IRec/'
+
+
+def main():
+    """Run PlumRQVAE inference with co-occurrence inputs and export unique semantic ids.
+
+    Steps:
+
+    1. Build the window-based co-occurrence mapping from ``inter_new.json``.
+    2. Run an ``InferenceRunner`` pass over the item embeddings; the
+       ``InferenceSaver`` callback is configured with each item's
+       ``item_id`` and ``clusters`` and the path
+       ``{EXPERIMENT_NAME}_clusters.json`` in the results directory.
+    3. Read that file back, give items that share an identical cluster tuple
+       one extra, distinct code, shift the code at position ``i`` by
+       ``CODEBOOK_SIZE * i`` and write
+       ``{EXPERIMENT_NAME}_clusters_colisionless.json``.
+
+    Raises:
+        AssertionError: if more than ``CODEBOOK_SIZE`` items share one
+            cluster tuple, because no distinct extra code would be left.
+    """
+    fix_random_seed(SEED_VALUE)
+
+    # CoocMappingDataset.create only takes the interactions path and the
+    # window size; the former max_sequence_length / sampler_type arguments
+    # are no longer part of its signature.
+    data = CoocMappingDataset.create(
+        inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter_new.json'),
+        window_size=WINDOW_SIZE
+    )
+
+    dataset = EmbeddingDataset(
+        data_path='/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl'
+    )
+
+    # Lookup tables consumed by the co-occurrence transform below.
+    item_id_to_embedding = {}
+    all_item_ids = []
+    for idx in range(len(dataset)):
+        sample = dataset[idx]
+        item_id = int(sample['item_id'])
+        item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
+        all_item_ids.append(item_id)
+
+    add_cooc_transform = AddWeightedCooccurrenceEmbeddings(
+        data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids)
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=False,
+        drop_last=False,
+    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
+        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+    ).map(add_cooc_transform)
+
+    model = PlumRQVAE(
+        input_dim=INPUT_DIM,
+        num_codebooks=NUM_CODEBOOKS,
+        codebook_size=CODEBOOK_SIZE,
+        embedding_dim=HIDDEN_DIM,
+        beta=BETA,
+        quant_loss_weight=1.0,
+        contrastive_loss_weight=1.0,
+        temperature=1.0
+    ).to(DEVICE)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.debug(f'Overall parameters: {total_params:,}')
+    logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+    # One definition of the results directory for every file touched below;
+    # it resolves to the same '/home/jovyan/IRec/results' used before.
+    results_dir = os.path.join(IREC_PATH, 'results')
+    clusters_path = os.path.join(results_dir, f'{EXPERIMENT_NAME}_clusters.json')
+    collisionless_path = os.path.join(results_dir, f'{EXPERIMENT_NAME}_clusters_colisionless.json')
+
+    callbacks = [
+        cb.LoadModel(MODEL_PATH),
+
+        cb.BatchMetrics(metrics=lambda model_outputs, _: {
+            'loss': model_outputs['loss'],
+            'recon_loss': model_outputs['recon_loss'],
+            'rqvae_loss': model_outputs['rqvae_loss'],
+            'con_loss': model_outputs['con_loss']
+        }, name='valid'),
+
+        cb.MetricAccumulator(
+            accumulators={
+                'valid/loss': cb.MeanAccumulator(),
+                'valid/recon_loss': cb.MeanAccumulator(),
+                'valid/rqvae_loss': cb.MeanAccumulator(),
+                'valid/con_loss': cb.MeanAccumulator(),
+            },
+        ),
+
+        cb.Logger().every_num_steps(len(dataloader)),
+
+        cb.InferenceSaver(
+            metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
+            save_path=clusters_path,
+            format='json'
+        )
+    ]
+
+    logger.debug('Everything is ready for inference process!')
+
+    runner = InferenceRunner(
+        model=model,
+        dataset=dataloader,
+        callbacks=callbacks,
+    )
+    runner.run()
+
+    import json
+    from collections import defaultdict
+    import numpy as np
+
+    with open(clusters_path, 'r') as f:
+        mappings = json.load(f)
+
+    # inter: item id -> list of codes; sem_2_ids: cluster tuple -> item ids sharing it.
+    inter = {}
+    sem_2_ids = defaultdict(list)
+    for mapping in mappings:
+        item_id = int(mapping['item_id'])
+        clusters = mapping['clusters']
+        inter[item_id] = clusters
+        sem_2_ids[tuple(clusters)].append(item_id)
+
+    for items in sem_2_ids.values():
+        assert len(items) <= CODEBOOK_SIZE, str(len(items))
+        # Distinct extra codes for the items of one group, drawn without replacement.
+        collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
+        for item_id, collision_solver in zip(items, collision_solvers):
+            inter[item_id].append(collision_solver)
+            # Shift position i by CODEBOOK_SIZE * i so every position has its own id range.
+            for i in range(len(inter[item_id])):
+                inter[item_id][i] += CODEBOOK_SIZE * i
+
+    # json.dump serialises the integer item ids as string keys.
+    with open(collisionless_path, 'w') as f:
+        json.dump(inter, f, indent=2)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/plum/cooc_data.py b/scripts/plum/cooc_data.py
index b11e6f0..50f2bdd 100644
--- a/scripts/plum/cooc_data.py
+++ b/scripts/plum/cooc_data.py
@@ -13,21 +13,15 @@ class CoocMappingDataset:
def __init__(
self,
train_sampler,
- validation_sampler,
- test_sampler,
num_items,
- max_sequence_length,
cooccur_counter_mapping=None
):
self._train_sampler = train_sampler
- self._validation_sampler = validation_sampler
- self._test_sampler = test_sampler
self._num_items = num_items
- self._max_sequence_length = max_sequence_length
self._cooccur_counter_mapping = cooccur_counter_mapping
@classmethod
- def create(cls, inter_json_path, max_sequence_length, sampler_type, window_size):
+ def create(cls, inter_json_path, window_size):
max_item_id = 0
train_dataset, validation_dataset, test_dataset = [], [], []
@@ -43,31 +37,59 @@ def create(cls, inter_json_path, max_sequence_length, sampler_type, window_size)
'user.ids': [user_id],
'item.ids': item_ids[:-2],
})
- validation_dataset.append({
- 'user.ids': [user_id],
- 'item.ids': item_ids[:-1],
- })
- test_dataset.append({
- 'user.ids': [user_id],
- 'item.ids': item_ids,
- })
cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size)
logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items but max_item_id is {max_item_id}')
train_sampler = train_dataset
- validation_sampler = validation_dataset
- test_sampler = test_dataset
return cls(
train_sampler=train_sampler,
- validation_sampler=validation_sampler,
- test_sampler=test_sampler,
num_items=max_item_id + 1,
- max_sequence_length=max_sequence_length,
cooccur_counter_mapping=cooccur_counter_mapping
)
+ @classmethod
+ def create_from_split_part(
+ cls,
+ train_inter_json_path,
+ window_size
+ ):
+ """Build the dataset from a JSON file that already holds only the train part.
+
+ The file is parsed as a mapping from user id (string) to a list of item
+ ids; each list is stored as-is under 'item.ids', with nothing trimmed
+ from its end.
+
+ Args:
+ train_inter_json_path: path of the JSON file with train interactions.
+ window_size: forwarded to build_cooccur_counter_mapping.
+
+ Returns:
+ An instance of cls whose num_items is the largest item id seen plus one.
+ """
+
+ max_item_id = 0
+ train_dataset = []
+
+ with open(train_inter_json_path, 'r') as f:
+ train_interactions = json.load(f)
+
+ # Process the TRAIN interactions
+ for user_id_str, item_ids in train_interactions.items():
+ user_id = int(user_id_str)
+ # max() of an empty list would raise, hence the guard.
+ if item_ids:
+ max_item_id = max(max_item_id, max(item_ids))
+
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': item_ids,
+ })
+
+ logger.debug(f'Train: {len(train_dataset)} users')
+ logger.debug(f'Max item ID: {max_item_id}')
+
+ cooccur_counter_mapping = cls.build_cooccur_counter_mapping(
+ train_dataset,
+ window_size=window_size
+ )
+
+ logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items')
+
+ return cls(
+ train_sampler=train_dataset,
+ num_items=max_item_id + 1,
+ cooccur_counter_mapping=cooccur_counter_mapping
+ )
+
+
@staticmethod
def build_cooccur_counter_mapping(train_dataset, window_size): #TODO передавать время и по нему строить окно
cooccur_counts = defaultdict(Counter)
@@ -80,16 +102,6 @@ def build_cooccur_counter_mapping(train_dataset, window_size): #TODO перед
cooccur_counts[item_i][items[j]] += 1
return cooccur_counts
- def get_datasets(self):
- return self._train_sampler, self._validation_sampler, self._test_sampler
-
- @property
- def num_items(self):
- return self._num_items
-
- @property
- def max_sequence_length(self):
- return self._max_sequence_length
@property
def cooccur_counter_mapping(self):
diff --git a/scripts/plum/data.py b/scripts/plum/data.py
index 0ffef82..9c15b70 100644
--- a/scripts/plum/data.py
+++ b/scripts/plum/data.py
@@ -5,6 +5,29 @@
from irec.data.transforms import Transform
+import polars as pl
+import torch
+
+class EmbeddingDatasetParquet(BaseDataset):
+    """Item embeddings loaded eagerly from a parquet file.
+
+    The file is read with polars and must provide the columns ``item_id``
+    and ``embedding``. Both columns are materialised as numpy arrays
+    (``int64`` ids, ``float32`` embeddings); the frame itself stays
+    available as ``self.df``.
+    """
+
+    def __init__(self, data_path):
+        frame = pl.read_parquet(data_path)
+        self.df = frame
+        self.item_ids = np.array(frame['item_id'], dtype=np.int64)
+        embedding_rows = frame['embedding'].to_list()
+        self.embeddings = np.array(embedding_rows, dtype=np.float32)
+        print(f"embedding dim: {self.embeddings[0].shape}")
+
+    def __getitem__(self, idx):
+        """Return the id, the embedding and the embedding length of row ``idx``."""
+        item = self.item_ids[idx]
+        vector = self.embeddings[idx]
+        sample = {'item_id': item, 'embedding': vector}
+        sample['embedding_dim'] = len(vector)
+        return sample
+
+    def __len__(self):
+        """Number of rows, taken from the embeddings array."""
+        return len(self.embeddings)
+
+
class EmbeddingDataset(BaseDataset):
def __init__(self, data_path):
self.data_path = data_path
diff --git a/scripts/plum/infer_default.py b/scripts/plum/infer_default.py
index af8df34..b15fb6d 100644
--- a/scripts/plum/infer_default.py
+++ b/scripts/plum/infer_default.py
@@ -12,8 +12,18 @@
from data import EmbeddingDataset, ProcessEmbeddings
from models import PlumRQVAE
-from transforms import AddWeightedCooccurrenceEmbeddings
-from cooc_data import CoocMappingDataset
+
+# Paths
+IREC_PATH = '/home/jovyan/IRec/'
+EMBEDDINGS_PATH = '/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl'
+MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-1_plum_rqvae_beauty_ws_2_best_0.0051.pth'
+RESULTS_PATH = os.path.join(IREC_PATH, 'results')
+
+WINDOW_SIZE = 2
+
+EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}'
+
+# Other settings
SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
@@ -26,29 +36,16 @@
NUM_CODEBOOKS = 3
BETA = 0.25
-MODEL_PATH = '/home/jovyan/IRec/checkpoints/test_plum_rqvae_beauty_ws_2_best_0.0054.pth'
-WINDOW_SIZE = 2
-
-EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}'
-
-IREC_PATH = '/home/jovyan/IRec/'
def main():
fix_random_seed(SEED_VALUE)
- data = CoocMappingDataset.create(
- inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter_new.json'),
- max_sequence_length=20,
- sampler_type='sasrec',
- window_size=WINDOW_SIZE
- )
-
dataset = EmbeddingDataset(
- data_path='/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl'
+ data_path=EMBEDDINGS_PATH
)
-
+
item_id_to_embedding = {}
all_item_ids = []
for idx in range(len(dataset)):
@@ -57,15 +54,12 @@ def main():
item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
all_item_ids.append(item_id)
- add_cooc_transform = AddWeightedCooccurrenceEmbeddings(
- data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids)
-
dataloader = DataLoader(
dataset,
batch_size=BATCH_SIZE,
shuffle=False,
drop_last=False,
- ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform)
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))
model = PlumRQVAE(
input_dim=INPUT_DIM,
@@ -106,8 +100,8 @@ def main():
cb.Logger().every_num_steps(len(dataloader)),
cb.InferenceSaver(
- metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
- save_path=f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json',
+ metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
+ save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'),
format='json'
)
]
@@ -125,9 +119,9 @@ def main():
from collections import defaultdict
import numpy as np
- with open(f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json', 'r') as f:
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f:
mappings = json.load(f)
-
+
inter = {}
sem_2_ids = defaultdict(list)
for mapping in mappings:
@@ -143,8 +137,8 @@ def main():
inter[item_id].append(collision_solver)
for i in range(len(inter[item_id])):
inter[item_id][i] += CODEBOOK_SIZE * i
-
- with open(os.path.join(IREC_PATH, 'results', f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
json.dump(inter, f, indent=2)
diff --git a/scripts/plum/train_plum.py b/scripts/plum/train_plum.py
index 5a00bc3..ffa9e43 100644
--- a/scripts/plum/train_plum.py
+++ b/scripts/plum/train_plum.py
@@ -41,8 +41,6 @@ def main():
data = CoocMappingDataset.create(
inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter_new.json'),
- max_sequence_length=20,
- sampler_type='sasrec',
window_size=WINDOW_SIZE
)
diff --git a/scripts/plum/train_plum_timestamp_based.py b/scripts/plum/train_plum_timestamp_based.py
new file mode 100644
index 0000000..e755d95
--- /dev/null
+++ b/scripts/plum/train_plum_timestamp_based.py
@@ -0,0 +1,168 @@
+from loguru import logger
+import os
+
+import torch
+
+import pickle
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDataset, ProcessEmbeddings
+from models import PlumRQVAE
+from transforms import AddWeightedCooccurrenceEmbeddings
+from cooc_data import CoocMappingDataset
+
+# Seed passed to fix_random_seed() at the start of main().
+SEED_VALUE = 42
+# Use CUDA when it is available, otherwise fall back to the CPU.
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+# The train dataloader is repeated this many times.
+NUM_EPOCHS = 500
+BATCH_SIZE = 1024
+
+# Arguments forwarded to the PlumRQVAE constructor in main().
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+BETA = 0.25
+# Learning rate of the Adam optimizer.
+LR = 1e-4
+# Window size for the co-occurrence mapping; also embedded in the experiment name.
+WINDOW_SIZE = 2
+
+EXPERIMENT_NAME = f'4-1_plum_rqvae_beauty_ws_{WINDOW_SIZE}'
+# Train-only interactions consumed by CoocMappingDataset.create_from_split_part.
+INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/Beauty_new/splits/exp_data/exp_4.1_inter_semantics_train.json"
+EMBEDDINGS_PATH = "/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl"
+# Relative path: 'tensorboard_logs' and 'checkpoints' are resolved against
+# the current working directory of the process.
+IREC_PATH = '../../'
+
+def main():
+    """Train a PlumRQVAE on item embeddings with co-occurrence data from the train split.
+
+    The co-occurrence mapping is built only from ``INTER_TRAIN_PATH``. The
+    same embedding dataset backs both dataloaders: the shuffled train loader
+    (repeated ``NUM_EPOCHS`` times) and the ordered loader used for codebook
+    initialisation, dead-centroid repair and validation. Early stopping
+    watches ``valid/recon_loss`` and is given a model path under
+    ``IREC_PATH/checkpoints``.
+    """
+    fix_random_seed(SEED_VALUE)
+
+    data = CoocMappingDataset.create_from_split_part(
+        train_inter_json_path=INTER_TRAIN_PATH,
+        window_size=WINDOW_SIZE
+    )
+
+    dataset = EmbeddingDataset(
+        data_path=EMBEDDINGS_PATH
+    )
+
+    # Lookup tables consumed by the co-occurrence transform below.
+    item_id_to_embedding = {}
+    all_item_ids = []
+    for idx in range(len(dataset)):
+        sample = dataset[idx]
+        item_id = int(sample['item_id'])
+        item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
+        all_item_ids.append(item_id)
+
+    add_cooc_transform = AddWeightedCooccurrenceEmbeddings(
+        data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids)
+
+    train_dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        drop_last=True,
+    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
+        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+    ).map(add_cooc_transform).repeat(NUM_EPOCHS)
+
+    valid_dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=False,
+        drop_last=False,
+    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
+        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+    ).map(add_cooc_transform)
+
+    # Presumably the number of steps in one pass over the data, assuming
+    # len() of the repeated loader counts all repetitions -- TODO confirm.
+    LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
+
+    model = PlumRQVAE(
+        input_dim=INPUT_DIM,
+        num_codebooks=NUM_CODEBOOKS,
+        codebook_size=CODEBOOK_SIZE,
+        embedding_dim=HIDDEN_DIM,
+        beta=BETA,
+        quant_loss_weight=1.0,
+        contrastive_loss_weight=1.0,
+        temperature=1.0
+    ).to(DEVICE)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.debug(f'Overall parameters: {total_params:,}')
+    logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+    # DEVICE may fall back to the CPU, and not every PyTorch release accepts
+    # fused Adam for CPU parameters, so only request it on CUDA.
+    optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=(DEVICE.type == 'cuda'))
+
+    callbacks = [
+        InitCodebooks(valid_dataloader),
+
+        cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+            'loss': model_outputs['loss'],
+            'recon_loss': model_outputs['recon_loss'],
+            'rqvae_loss': model_outputs['rqvae_loss'],
+            'con_loss': model_outputs['con_loss']
+        }, name='train'),
+
+        FixDeadCentroids(valid_dataloader),
+
+        cb.MetricAccumulator(
+            accumulators={
+                'train/loss': cb.MeanAccumulator(),
+                'train/recon_loss': cb.MeanAccumulator(),
+                'train/rqvae_loss': cb.MeanAccumulator(),
+                'train/con_loss': cb.MeanAccumulator(),
+                'num_dead/0': cb.MeanAccumulator(),
+                'num_dead/1': cb.MeanAccumulator(),
+                'num_dead/2': cb.MeanAccumulator(),
+            },
+            reset_every_num_steps=LOG_EVERY_NUM_STEPS
+        ),
+
+        cb.Validation(
+            dataset=valid_dataloader,
+            callbacks=[
+                cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+                    'loss': model_outputs['loss'],
+                    'recon_loss': model_outputs['recon_loss'],
+                    'rqvae_loss': model_outputs['rqvae_loss'],
+                    'con_loss': model_outputs['con_loss']
+                }, name='valid'),
+                cb.MetricAccumulator(
+                    accumulators={
+                        'valid/loss': cb.MeanAccumulator(),
+                        'valid/recon_loss': cb.MeanAccumulator(),
+                        'valid/rqvae_loss': cb.MeanAccumulator(),
+                        'valid/con_loss': cb.MeanAccumulator()
+                    }
+                ),
+            ],
+        ).every_num_steps(LOG_EVERY_NUM_STEPS),
+
+        cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
+        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+
+        cb.EarlyStopping(
+            metric='valid/recon_loss',
+            patience=40,
+            minimize=True,
+            model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
+        ).every_num_steps(LOG_EVERY_NUM_STEPS),
+    ]
+
+    logger.debug('Everything is ready for training process!')
+
+    runner = TrainingRunner(
+        model=model,
+        optimizer=optimizer,
+        dataset=train_dataloader,
+        callbacks=callbacks,
+    )
+    runner.run()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/plum/transforms.py b/scripts/plum/transforms.py
index 0af1dda..8887115 100644
--- a/scripts/plum/transforms.py
+++ b/scripts/plum/transforms.py
@@ -2,12 +2,6 @@
import pickle
import torch
-from irec.data.base import BaseDataset
-from irec.data.transforms import Transform
-
-from cooc_data import CoocMappingDataset
-
-
class AddWeightedCooccurrenceEmbeddings:
def __init__(self, cooccur_counts, item_id_to_embedding, all_item_ids):
self.cooccur_counts = cooccur_counts
@@ -32,7 +26,7 @@ def __call__(self, batch):
else:
cooc_id = np.random.choice(self.all_item_ids)
- if self.call_count % 10 == 0 and idx < 5:
+ if self.call_count % 500 == 0 and idx < 5:
print(f" idx={idx}: item_id={item_id_val} fallback random")
cooc_emb = self.item_id_to_embedding.get(cooc_id, batch['embedding'][0])
diff --git a/scripts/rqvae-yambda/callbacks.py b/scripts/rqvae-yambda/callbacks.py
new file mode 100644
index 0000000..43ec460
--- /dev/null
+++ b/scripts/rqvae-yambda/callbacks.py
@@ -0,0 +1,64 @@
+import torch
+
+import irec.callbacks as cb
+from irec.runners import TrainingRunner, TrainingRunnerContext
+
+class InitCodebooks(cb.TrainingCallback):
+    """Seed every codebook from encoded samples before training starts.
+
+    For codebook ``level`` a random subset of the first batch of the given
+    dataloader is encoded, the nearest vectors of all earlier codebooks are
+    subtracted one after another, and the remaining residuals become the
+    codebook's data.
+    """
+
+    def __init__(self, dataloader):
+        super().__init__()
+        self._dataloader = dataloader
+
+    @torch.no_grad()
+    def before_run(self, runner: TrainingRunner):
+        model = runner.model
+        for level, codebook in enumerate(model.codebooks):
+            first_batch = next(iter(self._dataloader))['embedding']
+            # Take as many random rows of the batch as the codebook has entries.
+            shuffled = torch.randperm(first_batch.shape[0], device=first_batch.device)
+            picked = shuffled[:len(codebook)]
+            residual = model.encoder(first_batch[picked])
+
+            # Remove what the already-initialised earlier levels explain.
+            for prev in range(level):
+                nearest = model.get_codebook_indices(residual, model.codebooks[prev])
+                residual = residual - model.codebooks[prev][nearest]
+
+            codebook.data = residual.detach()
+
+
+class FixDeadCentroids(cb.TrainingCallback):
+    """Re-seed codebook entries that no sample of the dataloader selects.
+
+    After every training step each codebook is checked over the whole
+    dataloader; the number of unused ("dead") entries per codebook is
+    reported as ``num_dead/<index>`` in the step metrics.
+    """
+
+    def __init__(self, dataloader):
+        super().__init__()
+        self._dataloader = dataloader
+
+    def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext):
+        dead_per_level = self.fix_dead_codebooks(runner)
+        for level, dead_count in enumerate(dead_per_level):
+            context.metrics[f'num_dead/{level}'] = dead_count
+
+    @staticmethod
+    def _residual_at_level(model, embeddings, level):
+        """Encode ``embeddings`` and subtract the nearest vectors of codebooks ``0..level-1``."""
+        residual = model.encoder(embeddings)
+        for prev in range(level):
+            nearest = model.get_codebook_indices(residual, model.codebooks[prev])
+            residual = residual - model.codebooks[prev][nearest]
+        return residual
+
+    @torch.no_grad()
+    def fix_dead_codebooks(self, runner: TrainingRunner):
+        """Replace unused entries of each codebook; return the unused count per codebook."""
+        model = runner.model
+        dead_per_level = []
+        for level, codebook in enumerate(model.codebooks):
+            usage = torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device)
+            # Fetched before the counting pass; only used if dead entries are found.
+            reseed_batch = next(iter(self._dataloader))['embedding']
+
+            # Count how often every entry is the nearest one over the whole dataloader.
+            for batch in self._dataloader:
+                residual = self._residual_at_level(model, batch['embedding'], level)
+                assigned = model.get_codebook_indices(residual, codebook)
+                usage.scatter_add_(0, assigned, torch.ones_like(assigned))
+
+            unused = (usage == 0)
+            dead_count = int(unused.sum().item())
+            dead_per_level.append(dead_count)
+            if dead_count > 0:
+                # Overwrite dead entries with randomly chosen residuals of the saved batch.
+                candidates = self._residual_at_level(model, reseed_batch, level)
+                order = torch.randperm(candidates.shape[0], device=codebook.device)
+                codebook[unused] = candidates[order][:dead_count].detach()
+
+        return dead_per_level
diff --git a/scripts/rqvae-yambda/data.py b/scripts/rqvae-yambda/data.py
new file mode 100644
index 0000000..6c213ee
--- /dev/null
+++ b/scripts/rqvae-yambda/data.py
@@ -0,0 +1,35 @@
+import numpy as np
+import polars as pl
+
+from irec.data.base import BaseDataset
+from irec.data.transforms import Transform
+
+
+class EmbeddingDatasetParquet(BaseDataset):
+    """Item embeddings loaded eagerly from a parquet file.
+
+    The file is read with polars and must provide the columns ``item_id``
+    and ``embedding``. Both columns are materialised as numpy arrays
+    (``int64`` ids, ``float32`` embeddings); the frame itself stays
+    available as ``self.df``.
+    """
+
+    def __init__(self, data_path):
+        frame = pl.read_parquet(data_path)
+        self.df = frame
+        self.item_ids = np.array(frame['item_id'], dtype=np.int64)
+        embedding_rows = frame['embedding'].to_list()
+        self.embeddings = np.array(embedding_rows, dtype=np.float32)
+        print(f"embedding dim: {self.embeddings[0].shape}")
+
+    def __getitem__(self, idx):
+        """Return the id, the embedding and the embedding length of row ``idx``."""
+        item = self.item_ids[idx]
+        vector = self.embeddings[idx]
+        sample = {'item_id': item, 'embedding': vector}
+        sample['embedding_dim'] = len(vector)
+        return sample
+
+    def __len__(self):
+        """Number of rows, taken from the embeddings array."""
+        return len(self.embeddings)
+
+class ProcessEmbeddings(Transform):
+    """Reshape the selected batch entries to ``(-1, embedding_dim)``.
+
+    Args:
+        embedding_dim: size of the last axis after reshaping.
+        keys: names of the batch entries to reshape.
+    """
+
+    def __init__(self, embedding_dim, keys):
+        self.embedding_dim = embedding_dim
+        self.keys = keys
+
+    def __call__(self, batch):
+        target_shape = (-1, self.embedding_dim)
+        for name in self.keys:
+            # The batch is updated in place and also returned.
+            batch[name] = batch[name].reshape(*target_shape)
+        return batch
\ No newline at end of file
diff --git a/scripts/rqvae-yambda/infer_yambda.py b/scripts/rqvae-yambda/infer_yambda.py
new file mode 100644
index 0000000..7daf42f
--- /dev/null
+++ b/scripts/rqvae-yambda/infer_yambda.py
@@ -0,0 +1,128 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from models import RQVAE
+
+
+# Absolute, machine-specific paths. IREC_PATH is not referenced anywhere else in this script.
+IREC_PATH = '/home/jovyan/IRec/'
+EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/yambda_data/yambda_embeddings_reindexed.parquet"
+# Checkpoint handed to cb.LoadModel in main().
+MODEL_PATH = '/home/jovyan/IRec/checkpoints/rqvae_yambda_hd_128_cz_512_best_0.0014.pth'
+# Directory holding the JSON files that main() writes and reads back.
+RESULTS_PATH = '/home/jovyan/IRec/rqvae-yambda-sem-ids'
+EXPERIMENT_NAME = 'rqvae_yambda_hd_128_cz_512'
+
+# Seed passed to fix_random_seed() at the start of main().
+SEED_VALUE = 42
+# Use CUDA when it is available, otherwise fall back to the CPU.
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+BATCH_SIZE = 1024
+
+# Arguments forwarded to the RQVAE constructor in main().
+INPUT_DIM = 128
+HIDDEN_DIM = 128
+# Also bounds the collision-resolving codes drawn in main().
+CODEBOOK_SIZE = 512
+NUM_CODEBOOKS = 3
+
+BETA = 0.25
+
+
+def main():
+    """Assign semantic ids to every Yambda item with an RQVAE and make them unique.
+
+    The run has two stages:
+
+    1. An ``InferenceRunner`` pass over ``EmbeddingDatasetParquet(EMBEDDINGS_PATH)``.
+       The ``InferenceSaver`` callback is configured with each item's
+       ``item_id`` and ``clusters`` and the path
+       ``{EXPERIMENT_NAME}_clusters.json`` inside ``RESULTS_PATH``.
+    2. Post-processing of that file: items sharing an identical cluster tuple
+       each receive one extra, distinct code, after which the code at
+       position ``i`` is shifted by ``CODEBOOK_SIZE * i``. The result is
+       written to ``{EXPERIMENT_NAME}_clusters_colisionless.json``.
+
+    Raises:
+        AssertionError: if more than ``CODEBOOK_SIZE`` items share one
+            cluster tuple, because no distinct extra code would be left.
+    """
+    fix_random_seed(SEED_VALUE)
+
+    dataset = EmbeddingDatasetParquet(
+        data_path=EMBEDDINGS_PATH
+    )
+
+    # Inference does not benefit from shuffling; a fixed order keeps the saved
+    # file and the collision codes independent of the loader's RNG.
+    dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=False,
+        drop_last=False,
+    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
+        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+    )
+
+    model = RQVAE(
+        input_dim=INPUT_DIM,
+        num_codebooks=NUM_CODEBOOKS,
+        codebook_size=CODEBOOK_SIZE,
+        embedding_dim=HIDDEN_DIM,
+        beta=BETA,
+        quant_loss_weight=1.0
+    ).to(DEVICE)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.debug(f'Overall parameters: {total_params:,}')
+    logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+    clusters_path = os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json')
+    collisionless_path = os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json')
+
+    callbacks = [
+        cb.LoadModel(MODEL_PATH),
+
+        cb.BatchMetrics(metrics=lambda model_outputs, _: {
+            'loss': model_outputs['loss'],
+            'recon_loss': model_outputs['recon_loss'],
+            'rqvae_loss': model_outputs['rqvae_loss'],
+        }, name='valid'),
+
+        cb.MetricAccumulator(
+            accumulators={
+                'valid/loss': cb.MeanAccumulator(),
+                'valid/recon_loss': cb.MeanAccumulator(),
+                'valid/rqvae_loss': cb.MeanAccumulator(),
+            },
+        ),
+
+        cb.Logger().every_num_steps(len(dataloader)),
+
+        cb.InferenceSaver(
+            metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
+            save_path=clusters_path,
+            format='json'
+        )
+    ]
+
+    logger.debug('Everything is ready for inference process!')
+
+    runner = InferenceRunner(
+        model=model,
+        dataset=dataloader,
+        callbacks=callbacks,
+    )
+    runner.run()
+
+    import json
+    from collections import defaultdict
+    import numpy as np
+
+    with open(clusters_path, 'r') as f:
+        mappings = json.load(f)
+
+    # inter: item id -> list of codes; sem_2_ids: cluster tuple -> item ids sharing it.
+    inter = {}
+    sem_2_ids = defaultdict(list)
+    for mapping in mappings:
+        item_id = int(mapping['item_id'])
+        clusters = mapping['clusters']
+        inter[item_id] = clusters
+        sem_2_ids[tuple(clusters)].append(item_id)
+
+    for items in sem_2_ids.values():
+        assert len(items) <= CODEBOOK_SIZE, str(len(items))
+        # Distinct extra codes for the items of one group, drawn without replacement.
+        collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
+        for item_id, collision_solver in zip(items, collision_solvers):
+            inter[item_id].append(collision_solver)
+            # Shift position i by CODEBOOK_SIZE * i so every position has its own id range.
+            for i in range(len(inter[item_id])):
+                inter[item_id][i] += CODEBOOK_SIZE * i
+
+    # json.dump serialises the integer item ids as string keys.
+    with open(collisionless_path, 'w') as f:
+        json.dump(inter, f, indent=2)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/rqvae-yambda/models.py b/scripts/rqvae-yambda/models.py
new file mode 100644
index 0000000..87c2241
--- /dev/null
+++ b/scripts/rqvae-yambda/models.py
@@ -0,0 +1,91 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class RQVAE(nn.Module):
+    """Residual-quantised autoencoder.
+
+    ``forward`` encodes ``inputs['embedding']``, quantises the latent with
+    ``num_codebooks`` codebooks applied one after another to the running
+    residual, and decodes the sum of the selected codebook vectors.
+
+    Args:
+        input_dim: size of the input embeddings (encoder input, decoder output).
+        num_codebooks: number of residual quantisation levels.
+        codebook_size: number of vectors in each codebook.
+        embedding_dim: size of the latent space and of every codebook vector.
+        beta: weight of the loss term that pulls the residual towards the
+            selected (detached) codebook vector; stored as a buffer.
+        quant_loss_weight: weight of the summed quantisation loss in the total loss.
+    """
+
+    def __init__(
+        self,
+        input_dim,
+        num_codebooks,
+        codebook_size,
+        embedding_dim,
+        beta=0.25,
+        quant_loss_weight=1.0,
+    ):
+        super().__init__()
+        self.register_buffer('beta', torch.tensor(beta))
+
+        self.input_dim = input_dim
+        self.num_codebooks = num_codebooks
+        self.codebook_size = codebook_size
+        self.embedding_dim = embedding_dim
+        self.quant_loss_weight = quant_loss_weight
+
+        self.encoder = self.make_encoding_tower(input_dim, embedding_dim)
+        self.decoder = self.make_encoding_tower(embedding_dim, input_dim)
+
+        self.codebooks = torch.nn.ParameterList()
+        for _ in range(num_codebooks):
+            # Zero-filled placeholder instead of uninitialised memory, which may
+            # hold NaN/Inf. Real values are expected to come from outside the
+            # constructor (a codebook-initialisation step or a loaded state dict).
+            # Zeros consume no random numbers, so seeded runs are unaffected.
+            codebook = nn.Parameter(torch.zeros(codebook_size, embedding_dim))
+            self.codebooks.append(codebook)
+
+    @staticmethod
+    def make_encoding_tower(d1, d2, bias=False):
+        """Build a three-layer MLP ``d1 -> d1 -> d2 -> d2`` with ReLU between layers.
+
+        ``bias`` only controls the bias of the last linear layer.
+        """
+        return torch.nn.Sequential(
+            nn.Linear(d1, d1),
+            nn.ReLU(),
+            nn.Linear(d1, d2),
+            nn.ReLU(),
+            nn.Linear(d2, d2, bias=bias)
+        )
+
+    @staticmethod
+    def get_codebook_indices(remainder, codebook):
+        """Return, for every row of ``remainder``, the index of the closest codebook row."""
+        dist = torch.cdist(remainder, codebook)
+        return dist.argmin(dim=-1)
+
+    def forward(self, inputs):
+        """Quantise and reconstruct ``inputs['embedding']``.
+
+        Returns:
+            A tuple ``(loss, outputs)``. ``loss`` is the scalar training loss.
+            ``outputs`` holds the float values ``loss``, ``recon_loss`` and
+            ``rqvae_loss``, the per-codebook usage counts
+            ``clusters_counts``, the selected indices ``clusters`` (one row
+            per sample, one column per codebook) and the reconstruction
+            ``embedding_hat``.
+        """
+        embeddings = inputs['embedding']
+        latent_vector = self.encoder(embeddings)
+
+        latent_restored = 0
+        rqvae_loss = 0
+        clusters = []
+        remainder = latent_vector
+
+        for codebook in self.codebooks:
+            codebook_indices = self.get_codebook_indices(remainder, codebook)
+            clusters.append(codebook_indices)
+
+            quantized = codebook[codebook_indices]
+            # Forward value equals `quantized`; gradients flow to `remainder` only.
+            codebook_vectors = remainder + (quantized - remainder).detach()
+
+            rqvae_loss += self.beta * F.mse_loss(remainder, quantized.detach())
+            rqvae_loss += F.mse_loss(quantized, remainder.detach())
+
+            latent_restored += codebook_vectors
+            remainder = remainder - codebook_vectors
+
+        embeddings_restored = self.decoder(latent_restored)
+        recon_loss = F.mse_loss(embeddings_restored, embeddings)
+
+        loss = (
+            recon_loss
+            + self.quant_loss_weight * rqvae_loss
+        ).mean()
+
+        clusters_counts = [
+            torch.bincount(cluster, minlength=self.codebook_size)
+            for cluster in clusters
+        ]
+
+        return loss, {
+            'loss': loss.item(),
+            'recon_loss': recon_loss.mean().item(),
+            'rqvae_loss': rqvae_loss.mean().item(),
+
+            'clusters_counts': clusters_counts,
+            'clusters': torch.stack(clusters).T,
+            'embedding_hat': embeddings_restored,
+        }
\ No newline at end of file
diff --git a/scripts/rqvae-yambda/train_yambda.py b/scripts/rqvae-yambda/train_yambda.py
new file mode 100644
index 0000000..71582ae
--- /dev/null
+++ b/scripts/rqvae-yambda/train_yambda.py
@@ -0,0 +1,151 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDatasetParquet, ProcessEmbeddings
+from models import RQVAE
+
# --- Reproducibility / hardware -------------------------------------------
SEED_VALUE = 42
# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Training schedule ------------------------------------------------------
NUM_EPOCHS = 100
BATCH_SIZE = 1024

# --- RQ-VAE hyper-parameters -------------------------------------------------
INPUT_DIM = 128
HIDDEN_DIM = 128
CODEBOOK_SIZE = 512
NUM_CODEBOOKS = 3
BETA = 0.25
LR = 1e-4

# --- Experiment identity and locations ---------------------------------------
EXPERIMENT_NAME = 'rqvae_yambda_hd_128_cz_512'
EMBEDDINGS_PATH = "/home/jovyan/IRec/sigir/yambda_data/yambda_embeddings_reindexed.parquet"
IREC_PATH = '../../'

# Echo the run identity as soon as the module is loaded.
print(EXPERIMENT_NAME, EMBEDDINGS_PATH)
def main():
    """Train the RQ-VAE on the embeddings stored at ``EMBEDDINGS_PATH``.

    Builds a shuffled training loader (repeated ``NUM_EPOCHS`` times) and an
    ordered validation loader over the *same* dataset, creates the model and
    optimizer, wires the callbacks and hands everything to ``TrainingRunner``.
    """
    fix_random_seed(SEED_VALUE)

    dataset = EmbeddingDatasetParquet(
        data_path=EMBEDDINGS_PATH
    )

    train_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        num_workers=8,
        shuffle=True,
        drop_last=True,
        persistent_workers=True,
        pin_memory=True
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
    ).repeat(NUM_EPOCHS)

    valid_dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=False,
    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
        ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
    )

    # The training loader is already repeated NUM_EPOCHS times, so this is
    # the number of steps in a single pass over the data.
    LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)

    model = RQVAE(
        input_dim=INPUT_DIM,
        num_codebooks=NUM_CODEBOOKS,
        codebook_size=CODEBOOK_SIZE,
        embedding_dim=HIDDEN_DIM,
        beta=BETA,
        quant_loss_weight=1.0
    ).to(DEVICE)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    # DEVICE may fall back to the CPU; only request the fused kernel on CUDA
    # because older PyTorch 2.x releases reject fused=True for CPU parameters.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=LR,
        fused=(DEVICE.type == 'cuda'),
    )

    def loss_metrics(model_outputs, batch):
        # Scalars that the model's forward pass reports for every batch.
        return {
            'loss': model_outputs['loss'],
            'recon_loss': model_outputs['recon_loss'],
            'rqvae_loss': model_outputs['rqvae_loss'],
        }

    train_accumulators = {
        'train/loss': cb.MeanAccumulator(),
        'train/recon_loss': cb.MeanAccumulator(),
        'train/rqvae_loss': cb.MeanAccumulator(),
    }
    # One dead-centroid accumulator per codebook level, so the list follows
    # NUM_CODEBOOKS instead of being hard-coded to three entries.
    # (presumably these keys are produced by FixDeadCentroids -- confirm.)
    for level in range(NUM_CODEBOOKS):
        train_accumulators[f'num_dead/{level}'] = cb.MeanAccumulator()

    callbacks = [
        InitCodebooks(valid_dataloader),

        cb.BatchMetrics(metrics=loss_metrics, name='train'),

        FixDeadCentroids(valid_dataloader),

        cb.MetricAccumulator(
            accumulators=train_accumulators,
            reset_every_num_steps=LOG_EVERY_NUM_STEPS
        ),

        cb.Validation(
            dataset=valid_dataloader,
            callbacks=[
                cb.BatchMetrics(metrics=loss_metrics, name='valid'),
                cb.MetricAccumulator(
                    accumulators={
                        'valid/loss': cb.MeanAccumulator(),
                        'valid/recon_loss': cb.MeanAccumulator(),
                        'valid/rqvae_loss': cb.MeanAccumulator(),
                    }
                ),
            ],
        ).every_num_steps(LOG_EVERY_NUM_STEPS),

        cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),

        cb.Profiler(
            wait=10,
            warmup=10,
            active=10,
            logdir=os.path.join(IREC_PATH, 'tensorboard_logs')
        ),

        cb.EarlyStopping(
            metric='valid/recon_loss',
            patience=40,
            minimize=True,
            model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
        ).every_num_steps(LOG_EVERY_NUM_STEPS),
    ]

    logger.debug('Everything is ready for training process!')

    runner = TrainingRunner(
        model=model,
        optimizer=optimizer,
        dataset=train_dataloader,
        callbacks=callbacks,
    )
    runner.run()


if __name__ == '__main__':
    main()
diff --git a/scripts/tiger-lsvd/data.py b/scripts/tiger-lsvd/data.py
new file mode 100644
index 0000000..26123db
--- /dev/null
+++ b/scripts/tiger-lsvd/data.py
@@ -0,0 +1,428 @@
+from collections import defaultdict
+import json
+from loguru import logger
+import numpy as np
+from pathlib import Path
+
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+import polars as pl
+from irec.data.base import BaseDataset
+
+
class InteractionsDatasetParquet(BaseDataset):
    """Per-user interaction histories read from a parquet file.

    The file must provide a ``uid`` column and an ``item_ids`` list column.
    When ``max_items`` is given, each list is sliced with
    ``list.slice(-max_items)`` so only its tail is kept.
    """

    def __init__(self, data_path, max_items=None):
        frame = pl.read_parquet(data_path)
        assert 'uid' in frame.columns, "Missing 'uid' column"
        assert 'item_ids' in frame.columns, "Missing 'item_ids' column"
        print(f"Dataset loaded: {len(frame)} users")

        if max_items is not None:
            tail_items = pl.col("item_ids").list.slice(-max_items).alias("item_ids")
            frame = frame.with_columns(tail_items)

        self.df = frame

    def __getitem__(self, idx):
        """Return ``{'user_id', 'item_ids'}`` for row ``idx``; item ids become a uint32 array."""
        record = self.df.row(idx, named=True)
        item_ids = np.array(record['item_ids'], dtype=np.uint32)
        return {
            'user_id': record['uid'],
            'item_ids': item_ids,
        }

    def __len__(self):
        return len(self.df)

    def __iter__(self):
        """Yield every row in storage order."""
        position = 0
        total = len(self)
        while position < total:
            yield self[position]
            position += 1
+
+
+
class Dataset:
    """Bundle of train / validation / test samplers plus dataset-level sizes."""

    def __init__(
        self,
        train_sampler,
        validation_sampler,
        test_sampler,
        num_items,
        max_sequence_length
    ):
        self._train_sampler = train_sampler
        self._validation_sampler = validation_sampler
        self._test_sampler = test_sampler
        self._num_items = num_items
        self._max_sequence_length = max_sequence_length

    @classmethod
    def create_timestamp_based_parquet(
        cls,
        train_parquet_path,
        validation_parquet_path,
        test_parquet_path,
        max_sequence_length,
        sampler_type,
        min_sample_len=2,
        is_extended=False,
        max_train_events=50
    ):
        """Build the dataset from three parquet files split by timestamp.

        Each parquet file is expected to contain:
            - ``uid``: int (user id)
            - ``item_ids``: list[int] (the user's item ids for that split)

        Args:
            train_parquet_path / validation_parquet_path / test_parquet_path:
                locations of the three splits.
            max_sequence_length: forwarded to the train / eval samplers.
            sampler_type: forwarded to ``TrainDataset``.
            min_sample_len: shortest sequence that still yields a validation /
                test sample, and the shortest prefix in extended mode.
            is_extended: emit one training sample per train prefix instead of
                a single sample holding the whole train history.
            max_train_events: keep only this many trailing train events.

        Returns:
            A ``Dataset`` whose ``num_items`` is ``max item id + 1``.
        """
        max_item_id = 0
        train_dataset, validation_dataset, test_dataset = [], [], []
        # Counters are totals over all users; they are created before the loop
        # so the summary below is also valid when no user is processed.
        valid_small_history = 0
        test_small_history = 0

        print(f"started to load datasets from parquet with max train length {max_train_events}")

        # Load the parquet files.
        train_df = pl.read_parquet(train_parquet_path)
        validation_df = pl.read_parquet(validation_parquet_path)
        test_df = pl.read_parquet(test_parquet_path)

        # Make sure the required columns are present.
        for df, name in [(train_df, "train"), (validation_df, "validation"), (test_df, "test")]:
            assert 'uid' in df.columns, f"Missing 'uid' column in {name}"
            assert 'item_ids' in df.columns, f"Missing 'item_ids' column in {name}"

        # Per-split lookup tables keyed by the stringified user id.
        train_data = {str(row['uid']): row['item_ids'] for row in train_df.iter_rows(named=True)}
        validation_data = {str(row['uid']): row['item_ids'] for row in validation_df.iter_rows(named=True)}
        test_data = {str(row['uid']): row['item_ids'] for row in test_df.iter_rows(named=True)}

        all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys())
        print(f"all users count: {len(all_users)}")

        # Iterating a set of strings directly depends on the interpreter's hash
        # seed; sorting makes the sample order identical between runs.
        for us_count, user_id_str in enumerate(sorted(all_users, key=int)):
            report_progress = us_count % 100 == 0
            if report_progress:
                print(f"user id {us_count}/{len(all_users)}: {user_id_str}")

            user_id = int(user_id_str)

            # Sequences of this user in each split.
            train_items = list(train_data.get(user_id_str, []))
            validation_items = list(validation_data.get(user_id_str, []))
            test_items = list(test_data.get(user_id_str, []))

            # Keep only the last max_train_events train events.
            train_items = train_items[-max_train_events:] if len(train_items) > max_train_events else train_items

            full_sequence = train_items + validation_items + test_items
            if full_sequence:
                max_item_id = max(max_item_id, max(full_sequence))

            if report_progress:
                print(f"full sequence len: {len(full_sequence)}")

            if len(full_sequence) < 4:
                print(f'Core-4 dataset is used, user {user_id} has only {len(full_sequence)} items')
                continue

            if is_extended:
                # One sample per prefix: [1, 2], [1, 2, 3], ..., [1, ..., 8]
                for prefix_length in range(min_sample_len, len(train_items) + 1):
                    train_dataset.append({
                        'user.ids': [user_id],
                        'item.ids': train_items[:prefix_length],
                    })
            else:
                # Single sample with the whole train history: [1, ..., 8]
                # NOTE(review): unlike the extended branch this does not apply
                # min_sample_len, so a user without train events still yields
                # a (possibly empty) training sample -- confirm this is wanted.
                train_dataset.append({
                    'user.ids': [user_id],
                    'item.ids': train_items,
                })

            # Validation: every validation item becomes its own sample.
            # Example: Train=[1,2], Valid=[3,4] -> [1, 2, 3] and [1, 2, 3, 4]
            current_history = train_items.copy()
            for item in validation_items:
                # The eval dataset cuts the target off by itself later on.
                sample_sequence = current_history + [item]

                if len(sample_sequence) >= min_sample_len:
                    validation_dataset.append({
                        'user.ids': [user_id],
                        'item.ids': sample_sequence,
                    })
                else:
                    valid_small_history += 1
                current_history.append(item)

            # Test: every test item becomes its own sample.
            # Example: Train=[1,2], Valid=[3,4], Test=[5,6]
            #   -> [1, 2, 3, 4, 5] and [1, 2, 3, 4, 5, 6]
            current_history = train_items + validation_items
            for item in test_items:
                sample_sequence = current_history + [item]
                if len(sample_sequence) >= min_sample_len:
                    test_dataset.append({
                        'user.ids': [user_id],
                        'item.ids': sample_sequence,
                    })
                else:
                    test_small_history += 1
                current_history.append(item)

        print(f"Train dataset size: {len(train_dataset)}")
        print(f"Validation dataset size: {len(validation_dataset)} with skipped {valid_small_history}")
        print(f"Test dataset size: {len(test_dataset)} with skipped {test_small_history}")

        logger.debug(f'Train dataset size: {len(train_dataset)}')
        logger.debug(f'Validation dataset size: {len(validation_dataset)}')
        logger.debug(f'Test dataset size: {len(test_dataset)}')

        train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length)
        validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length)
        test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length)

        return cls(
            train_sampler=train_sampler,
            validation_sampler=validation_sampler,
            test_sampler=test_sampler,
            num_items=max_item_id + 1,  # +1 added because our ids are 0-indexed
            max_sequence_length=max_sequence_length
        )

    @classmethod
    def create(cls, inter_json_path, max_sequence_length, sampler_type, is_extended=False):
        """Build the dataset from a ``{user_id: [item ids]}`` JSON file (leave-one-out).

        For a sequence ``[1, ..., 10]`` the train part is ``[1..8]``, the
        validation sample is ``[1..9]`` and the test sample is ``[1..10]``.
        Every user must have at least five items (asserted).
        """
        max_item_id = 0
        train_dataset, validation_dataset, test_dataset = [], [], []

        with open(inter_json_path, 'r') as f:
            user_interactions = json.load(f)

        for user_id_str, item_ids in user_interactions.items():
            user_id = int(user_id_str)

            if item_ids:
                max_item_id = max(max_item_id, max(item_ids))

            assert len(item_ids) >= 5, f'Core-5 dataset is used, user {user_id} has only {len(item_ids)} items'

            if is_extended:
                # One sample per prefix of the train part: [1, 2] ... [1..8]
                for prefix_length in range(2, len(item_ids) - 2 + 1):
                    train_dataset.append({
                        'user.ids': [user_id],
                        'item.ids': item_ids[:prefix_length],
                    })
            else:
                # Single sample: everything except the last two items.
                train_dataset.append({
                    'user.ids': [user_id],
                    'item.ids': item_ids[:-2],
                })

            # Validation sample: everything except the last item.
            validation_dataset.append({
                'user.ids': [user_id],
                'item.ids': item_ids[:-1],
            })

            # Test sample: the full sequence.
            test_dataset.append({
                'user.ids': [user_id],
                'item.ids': item_ids,
            })

        logger.debug(f'Train dataset size: {len(train_dataset)}')
        logger.debug(f'Validation dataset size: {len(validation_dataset)}')
        logger.debug(f'Test dataset size: {len(test_dataset)}')
        logger.debug(f'Max item id: {max_item_id}')

        train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length)
        validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length)
        test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length)

        return cls(
            train_sampler=train_sampler,
            validation_sampler=validation_sampler,
            test_sampler=test_sampler,
            num_items=max_item_id + 1,  # +1 added because our ids are 0-indexed
            max_sequence_length=max_sequence_length
        )

    def get_datasets(self):
        """Return ``(train_sampler, validation_sampler, test_sampler)``."""
        return self._train_sampler, self._validation_sampler, self._test_sampler

    @property
    def num_items(self):
        return self._num_items

    @property
    def max_sequence_length(self):
        return self._max_sequence_length
+
+
class TrainDataset(BaseDataset):
    """Turns raw ``{'user.ids', 'item.ids'}`` samples into training examples.

    ``prediction_type`` selects the target layout:
        * ``'sasrec'`` -- every position predicts the following item;
        * ``'tiger'``  -- the whole prefix predicts only the last item.
    """

    def __init__(self, dataset, prediction_type, max_sequence_length):
        self._dataset = dataset
        self._prediction_type = prediction_type
        self._max_sequence_length = max_sequence_length

        self._transforms = {
            'sasrec': self._all_items_transform,
            'tiger': self._last_item_transform
        }

    def _recent_items(self, sample):
        # Trailing window of at most max_sequence_length items.
        return sample['item.ids'][-self._max_sequence_length:]

    @staticmethod
    def _as_example(sample, history, labels):
        # Pack ids and their lengths into int64 arrays under the shared key layout.
        user_ids = sample['user.ids']
        return {
            'user.ids': np.array(user_ids, dtype=np.int64),
            'user.length': np.array([len(user_ids)], dtype=np.int64),
            'item.ids': np.array(history, dtype=np.int64),
            'item.length': np.array([len(history)], dtype=np.int64),
            'labels.ids': np.array(labels, dtype=np.int64),
            'labels.length': np.array([len(labels)], dtype=np.int64),
        }

    def _all_items_transform(self, sample):
        """Inputs are the window without its last item, labels the window shifted by one."""
        window = self._recent_items(sample)
        return self._as_example(sample, window[:-1], window[1:])

    def _last_item_transform(self, sample):
        """Inputs are the window without its last item, the label is that last item."""
        window = self._recent_items(sample)
        history = window[:-1]
        target = window[-1]
        return self._as_example(sample, history, [target])

    def __getitem__(self, index):
        transform = self._transforms[self._prediction_type]
        return transform(self._dataset[index])

    def __len__(self):
        return len(self._dataset)
+
+
class EvalDataset(BaseDataset):
    """Evaluation view: the last item of every stored sequence is the target."""

    def __init__(self, dataset, max_sequence_length):
        self._dataset = dataset
        self._max_sequence_length = max_sequence_length

    @property
    def dataset(self):
        return self._dataset

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, index):
        """Return model inputs, the single target and the visited items.

        Inputs come from the trailing ``max_sequence_length`` window, whereas
        ``visited.*`` covers the *whole* sequence except the target.
        """
        sample = self._dataset[index]
        all_items = sample['item.ids']
        user_ids = sample['user.ids']

        window = all_items[-self._max_sequence_length:]
        history = window[:-1]
        target = window[-1]
        visited = all_items[:-1]

        def as_int64(values):
            return np.array(values, dtype=np.int64)

        return {
            'user.ids': as_int64(user_ids),
            'user.length': as_int64([len(user_ids)]),
            'item.ids': as_int64(history),
            'item.length': as_int64([len(history)]),
            'labels.ids': as_int64([target]),
            'labels.length': as_int64([1]),
            'visited.ids': as_int64(visited),
            'visited.length': as_int64([len(visited)]),
        }
+
+
class ArrowBatchDataset(BaseDataset):
    """Dataset whose items are whole pre-built batches stored as Arrow files.

    Files are discovered with the glob ``batch_*_len_*.arrow``; the integer
    after ``batch_`` groups files into one batch. With ``preload=True`` every
    batch is loaded (and moved to ``device``) in the constructor, otherwise
    batches are read from disk on each access.
    """

    def __init__(self, batch_dir, device='cuda', preload=False):
        self.batch_dir = Path(batch_dir)
        self.device = device

        # Group files by the batch id encoded as the second '_' field of the stem.
        batch_files_map = defaultdict(list)
        for arrow_path in list(self.batch_dir.glob('batch_*_len_*.arrow')):
            batch_files_map[int(arrow_path.stem.split('_')[1])].append(arrow_path)

        for grouped_files in batch_files_map.values():
            grouped_files.sort()

        self.batch_indices = sorted(batch_files_map.keys())

        if preload:
            print(f"Preloading {len(self.batch_indices)} batches...")
            self.cached_batches = []
            for batch_id in self.batch_indices:
                self.cached_batches.append(self._load_batch(batch_files_map[batch_id]))
        else:
            self.cached_batches = None
        # Kept for on-demand loading in __getitem__.
        self.batch_files_map = batch_files_map

    def _load_batch(self, arrow_files):
        """Read all files of one batch into a ``{column name: tensor}`` dict on ``self.device``.

        A column found in several files ends up with the value from the last file.
        """
        batch = {}

        for arrow_file in arrow_files:
            table = feather.read_table(arrow_file)
            metadata = table.schema.metadata or {}

            for col_name in table.column_names:
                col = table.column(col_name)
                is_list_column = pa.types.is_list(col.type) or pa.types.is_large_list(col.type)

                shape_key = f'{col_name}_shape'.encode()
                dtype_key = f'{col_name}_dtype'.encode()

                if shape_key in metadata:
                    # NOTE(security): eval() executes whatever text the file
                    # metadata contains -- only load Arrow files from trusted
                    # sources. ast.literal_eval would be safer if the stored
                    # shape is a plain tuple literal; confirm against the writer.
                    shape = eval(metadata[shape_key].decode())
                    dtype = np.dtype(metadata[dtype_key].decode())

                    if is_list_column:
                        # Nested lists already carry their shape.
                        arr = np.array(col.to_pylist(), dtype=dtype)
                    else:
                        arr = col.to_numpy().reshape(shape).astype(dtype)
                elif is_list_column:
                    arr = np.array(col.to_pylist())
                else:
                    arr = col.to_numpy()

                batch[col_name] = torch.from_numpy(arr.copy()).to(self.device)

        return batch

    def __len__(self):
        return len(self.batch_indices)

    def __getitem__(self, idx):
        if self.cached_batches is None:
            return self._load_batch(self.batch_files_map[self.batch_indices[idx]])
        return self.cached_batches[idx]
diff --git a/scripts/tiger-lsvd/lsvd_train_4.1_plum.py b/scripts/tiger-lsvd/lsvd_train_4.1_plum.py
new file mode 100644
index 0000000..96a08d7
--- /dev/null
+++ b/scripts/tiger-lsvd/lsvd_train_4.1_plum.py
@@ -0,0 +1,230 @@
+from functools import partial
+import json
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.transforms import Collate, ToDevice
+from irec.data.dataloader import DataLoader
+from irec.runners import TrainingRunner
+from irec.utils import fix_random_seed
+
+from data import ArrowBatchDataset
+from models import TigerModel, CorrectItemsLogitsProcessor
+
+
+# ПУТИ
+IREC_PATH = '../../'
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/4-1_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e_35_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.1/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.1/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.1/eval_batches/')
+
+
+TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs')
+CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints')
+
+EXPERIMENT_NAME = 'tiger_4-1_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e_35'
+
+# ОСТАЛЬНОЕ
+SEED_VALUE = 42
+DEVICE = 'cuda'
+
+NUM_EPOCHS = 300
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+EMBEDDING_DIM = 128
+CODEBOOK_SIZE = 512
+NUM_POSITIONS = 80
+NUM_USER_HASH = 2000
+NUM_HEADS = 6
+NUM_LAYERS = 4
+FEEDFORWARD_DIM = 1024
+KV_DIM = 64
+DROPOUT = 0.2
+NUM_BEAMS = 30
+TOP_K = 20
+NUM_CODEBOOKS = 4
+LR = 0.0001
+
+USE_MICROBATCHING = True
+MICROBATCH_SIZE = 256
+
+torch.set_float32_matmul_precision('high')
+torch._dynamo.config.capture_scalar_outputs = True
+
+import torch._inductor.config as config
+config.triton.cudagraph_skip_dynamic_graphs = True
+
+
def main():
    """Train TIGER on the pre-batched 4.1 LSVD split.

    Loads the semantic-id mapping, preloads the Arrow batch datasets, builds
    the model and optimizer, wires the callbacks and starts ``TrainingRunner``.
    """
    fix_random_seed(SEED_VALUE)

    with open(SEMANTIC_MAPPING_PATH, 'r') as mapping_file:
        mappings = json.load(mapping_file)

    # Training batches stay on the CPU and are moved to DEVICE one at a time.
    train_batches = ArrowBatchDataset(TRAIN_BATCHES_DIR, device='cpu', preload=True)
    train_dataloader = DataLoader(
        train_batches,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        collate_fn=Collate()
    ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)

    # Validation / eval batches are loaded straight onto DEVICE.
    valid_dataloder = ArrowBatchDataset(VALID_BATCHES_DIR, device=DEVICE, preload=True)
    eval_dataloder = ArrowBatchDataset(EVAL_BATCHES_DIR, device=DEVICE, preload=True)

    logits_processor = partial(
        CorrectItemsLogitsProcessor,
        NUM_CODEBOOKS,
        CODEBOOK_SIZE,
        mappings,
        NUM_BEAMS
    )

    model = TigerModel(
        embedding_dim=EMBEDDING_DIM,
        codebook_size=CODEBOOK_SIZE,
        sem_id_len=NUM_CODEBOOKS,
        user_ids_count=NUM_USER_HASH,
        num_positions=NUM_POSITIONS,
        num_heads=NUM_HEADS,
        num_encoder_layers=NUM_LAYERS,
        num_decoder_layers=NUM_LAYERS,
        dim_feedforward=FEEDFORWARD_DIM,
        num_beams=NUM_BEAMS,
        num_return_sequences=TOP_K,
        activation='relu',
        d_kv=KV_DIM,
        dropout=DROPOUT,
        layer_norm_eps=1e-6,
        initializer_range=0.02,
        logits_processor=logits_processor,
        use_microbatching=USE_MICROBATCHING,
        microbatch_size=MICROBATCH_SIZE
    ).to(DEVICE)

    total_params = 0
    trainable_params = 0
    for parameter in model.parameters():
        total_params += parameter.numel()
        if parameter.requires_grad:
            trainable_params += parameter.numel()

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    # Fixed logging / validation period, in optimisation steps.
    EPOCH_NUM_STEPS = 1024  # int(len(train_dataloader) // NUM_EPOCHS)

    ranking_keys = ('recall@5', 'recall@10', 'recall@20', 'ndcg@5', 'ndcg@10', 'ndcg@20')

    def ranking_metrics(model_outputs, _):
        # Scalar loss plus the ranking metrics converted with .tolist().
        metrics = {'loss': model_outputs['loss'].item()}
        for key in ranking_keys:
            metrics[key] = model_outputs[key].tolist()
        return metrics

    def ranking_validation(dataset, name):
        # Validation callback reporting mean loss / recall / ndcg under `name/`.
        accumulators = {f'{name}/loss': cb.MeanAccumulator()}
        for key in ranking_keys:
            accumulators[f'{name}/{key}'] = cb.MeanAccumulator()
        return cb.Validation(
            dataset=dataset,
            callbacks=[
                cb.BatchMetrics(metrics=ranking_metrics, name=name),
                cb.MetricAccumulator(accumulators=accumulators),
            ],
        )

    callbacks = [
        cb.BatchMetrics(metrics=lambda model_outputs, _: {
            'loss': model_outputs['loss'].item(),
        }, name='train'),
        cb.MetricAccumulator(
            accumulators={
                'train/loss': cb.MeanAccumulator(),
            },
            reset_every_num_steps=EPOCH_NUM_STEPS
        ),

        ranking_validation(valid_dataloder, 'validation').every_num_steps(EPOCH_NUM_STEPS),
        # The eval split is scored four times less often than validation.
        ranking_validation(eval_dataloder, 'eval').every_num_steps(EPOCH_NUM_STEPS * 4),

        cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),

        cb.EarlyStopping(
            metric='validation/ndcg@20',
            patience=40 * 4,
            minimize=False,
            model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
        ).every_num_steps(EPOCH_NUM_STEPS),
    ]

    logger.debug('Everything is ready for training process!')

    runner = TrainingRunner(
        model=model,
        optimizer=optimizer,
        dataset=train_dataloader,
        callbacks=callbacks,
    )
    runner.run()


if __name__ == '__main__':
    main()
diff --git a/scripts/tiger-lsvd/lsvd_train_4.2_plum.py b/scripts/tiger-lsvd/lsvd_train_4.2_plum.py
new file mode 100644
index 0000000..991f662
--- /dev/null
+++ b/scripts/tiger-lsvd/lsvd_train_4.2_plum.py
@@ -0,0 +1,230 @@
+from functools import partial
+import json
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.transforms import Collate, ToDevice
+from irec.data.dataloader import DataLoader
+from irec.runners import TrainingRunner
+from irec.utils import fix_random_seed
+
+from data import ArrowBatchDataset
+from models import TigerModel, CorrectItemsLogitsProcessor
+
+
+# ПУТИ
+IREC_PATH = '../../'
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/4-2_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e_35_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.2/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.2/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.2/eval_batches/')
+
+
+TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs')
+CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints')
+
+EXPERIMENT_NAME = 'tiger_4-2_vk_lsvd_ods_base_cb_512_ws_2_k_2000_8w_e_35'
+
+# ОСТАЛЬНОЕ
+SEED_VALUE = 42
+DEVICE = 'cuda'
+
+NUM_EPOCHS = 300
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+EMBEDDING_DIM = 128
+CODEBOOK_SIZE = 512
+NUM_POSITIONS = 80
+NUM_USER_HASH = 2000
+NUM_HEADS = 6
+NUM_LAYERS = 4
+FEEDFORWARD_DIM = 1024
+KV_DIM = 64
+DROPOUT = 0.2
+NUM_BEAMS = 30
+TOP_K = 20
+NUM_CODEBOOKS = 4
+LR = 0.0001
+
+USE_MICROBATCHING = True
+MICROBATCH_SIZE = 256
+
+torch.set_float32_matmul_precision('high')
+torch._dynamo.config.capture_scalar_outputs = True
+
+import torch._inductor.config as config
+config.triton.cudagraph_skip_dynamic_graphs = True
+
+
def main():
    """Train TIGER on the pre-batched 4.2 LSVD split.

    Loads the semantic-id mapping, preloads the Arrow batch datasets, builds
    the model and optimizer, wires the callbacks and starts ``TrainingRunner``.
    """
    fix_random_seed(SEED_VALUE)

    with open(SEMANTIC_MAPPING_PATH, 'r') as mapping_file:
        mappings = json.load(mapping_file)

    # Training batches stay on the CPU and are moved to DEVICE one at a time.
    train_batches = ArrowBatchDataset(TRAIN_BATCHES_DIR, device='cpu', preload=True)
    train_dataloader = DataLoader(
        train_batches,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        collate_fn=Collate()
    ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)

    # Validation / eval batches are loaded straight onto DEVICE.
    valid_dataloder = ArrowBatchDataset(VALID_BATCHES_DIR, device=DEVICE, preload=True)
    eval_dataloder = ArrowBatchDataset(EVAL_BATCHES_DIR, device=DEVICE, preload=True)

    logits_processor = partial(
        CorrectItemsLogitsProcessor,
        NUM_CODEBOOKS,
        CODEBOOK_SIZE,
        mappings,
        NUM_BEAMS
    )

    model = TigerModel(
        embedding_dim=EMBEDDING_DIM,
        codebook_size=CODEBOOK_SIZE,
        sem_id_len=NUM_CODEBOOKS,
        user_ids_count=NUM_USER_HASH,
        num_positions=NUM_POSITIONS,
        num_heads=NUM_HEADS,
        num_encoder_layers=NUM_LAYERS,
        num_decoder_layers=NUM_LAYERS,
        dim_feedforward=FEEDFORWARD_DIM,
        num_beams=NUM_BEAMS,
        num_return_sequences=TOP_K,
        activation='relu',
        d_kv=KV_DIM,
        dropout=DROPOUT,
        layer_norm_eps=1e-6,
        initializer_range=0.02,
        logits_processor=logits_processor,
        use_microbatching=USE_MICROBATCHING,
        microbatch_size=MICROBATCH_SIZE
    ).to(DEVICE)

    total_params = 0
    trainable_params = 0
    for parameter in model.parameters():
        total_params += parameter.numel()
        if parameter.requires_grad:
            trainable_params += parameter.numel()

    logger.debug(f'Overall parameters: {total_params:,}')
    logger.debug(f'Trainable parameters: {trainable_params:,}')

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    # Fixed logging / validation period, in optimisation steps.
    EPOCH_NUM_STEPS = 1024  # int(len(train_dataloader) // NUM_EPOCHS)

    ranking_keys = ('recall@5', 'recall@10', 'recall@20', 'ndcg@5', 'ndcg@10', 'ndcg@20')

    def ranking_metrics(model_outputs, _):
        # Scalar loss plus the ranking metrics converted with .tolist().
        metrics = {'loss': model_outputs['loss'].item()}
        for key in ranking_keys:
            metrics[key] = model_outputs[key].tolist()
        return metrics

    def ranking_validation(dataset, name):
        # Validation callback reporting mean loss / recall / ndcg under `name/`.
        accumulators = {f'{name}/loss': cb.MeanAccumulator()}
        for key in ranking_keys:
            accumulators[f'{name}/{key}'] = cb.MeanAccumulator()
        return cb.Validation(
            dataset=dataset,
            callbacks=[
                cb.BatchMetrics(metrics=ranking_metrics, name=name),
                cb.MetricAccumulator(accumulators=accumulators),
            ],
        )

    callbacks = [
        cb.BatchMetrics(metrics=lambda model_outputs, _: {
            'loss': model_outputs['loss'].item(),
        }, name='train'),
        cb.MetricAccumulator(
            accumulators={
                'train/loss': cb.MeanAccumulator(),
            },
            reset_every_num_steps=EPOCH_NUM_STEPS
        ),

        ranking_validation(valid_dataloder, 'validation').every_num_steps(EPOCH_NUM_STEPS),
        # The eval split is scored four times less often than validation.
        ranking_validation(eval_dataloder, 'eval').every_num_steps(EPOCH_NUM_STEPS * 4),

        cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),

        cb.EarlyStopping(
            metric='validation/ndcg@20',
            patience=40 * 4,
            minimize=False,
            model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
        ).every_num_steps(EPOCH_NUM_STEPS),
    ]

    logger.debug('Everything is ready for training process!')

    runner = TrainingRunner(
        model=model,
        optimizer=optimizer,
        dataset=train_dataloader,
        callbacks=callbacks,
    )
    runner.run()


if __name__ == '__main__':
    main()
diff --git a/scripts/tiger-lsvd/lsvd_train_rqvae.py b/scripts/tiger-lsvd/lsvd_train_rqvae.py
new file mode 100644
index 0000000..aadf225
--- /dev/null
+++ b/scripts/tiger-lsvd/lsvd_train_rqvae.py
@@ -0,0 +1,230 @@
+from functools import partial
+import json
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.transforms import Collate, ToDevice
+from irec.data.dataloader import DataLoader
+from irec.runners import TrainingRunner
+from irec.utils import fix_random_seed
+
+from data import ArrowBatchDataset
+from models import TigerModel, CorrectItemsLogitsProcessor
+
+
+# PATHS
+# Every location below is built relative to IREC_PATH.
+IREC_PATH = '../../'
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/rqvae_vk_lsvd_cz_512_8-weeks_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-rqvae/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-rqvae/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-rqvae/eval_batches/')
+
+
+TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs')
+CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints')
+
+# Used as the TensorBoard experiment name and as the checkpoint file name.
+EXPERIMENT_NAME = 'tiger_rqvae_vk_lsvd_cz_512_8-weeks'
+
+# MISC
+SEED_VALUE = 42
+DEVICE = 'cuda'
+
+# Model / training hyperparameters read by main().
+# NOTE(review): MAX_SEQ_LEN, TRAIN_BATCH_SIZE and VALID_BATCH_SIZE are not
+# referenced anywhere else in this script - confirm whether they are still needed.
+NUM_EPOCHS = 300
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+EMBEDDING_DIM = 128
+CODEBOOK_SIZE = 512
+NUM_POSITIONS = 80
+NUM_USER_HASH = 2000
+NUM_HEADS = 6
+NUM_LAYERS = 4
+FEEDFORWARD_DIM = 1024
+KV_DIM = 64
+DROPOUT = 0.2
+NUM_BEAMS = 30
+TOP_K = 20
+NUM_CODEBOOKS = 4
+LR = 0.0001
+
+USE_MICROBATCHING = True
+MICROBATCH_SIZE = 256
+
+# Process-wide torch / compiler settings, applied as soon as the module is imported.
+torch.set_float32_matmul_precision('high')
+torch._dynamo.config.capture_scalar_outputs = True
+
+import torch._inductor.config as config
+config.triton.cudagraph_skip_dynamic_graphs = True
+
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+ mappings = json.load(f)
+
+ train_dataloader = DataLoader(
+ ArrowBatchDataset(
+ TRAIN_BATCHES_DIR,
+ device='cpu',
+ preload=True
+ ),
+ batch_size=1,
+ shuffle=True,
+ num_workers=0,
+ pin_memory=True,
+ collate_fn=Collate()
+ ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)
+
+ valid_dataloder = ArrowBatchDataset(
+ VALID_BATCHES_DIR,
+ device=DEVICE,
+ preload=True
+ )
+
+ eval_dataloder = ArrowBatchDataset(
+ EVAL_BATCHES_DIR,
+ device=DEVICE,
+ preload=True
+ )
+
+ model = TigerModel(
+ embedding_dim=EMBEDDING_DIM,
+ codebook_size=CODEBOOK_SIZE,
+ sem_id_len=NUM_CODEBOOKS,
+ user_ids_count=NUM_USER_HASH,
+ num_positions=NUM_POSITIONS,
+ num_heads=NUM_HEADS,
+ num_encoder_layers=NUM_LAYERS,
+ num_decoder_layers=NUM_LAYERS,
+ dim_feedforward=FEEDFORWARD_DIM,
+ num_beams=NUM_BEAMS,
+ num_return_sequences=TOP_K,
+ activation='relu',
+ d_kv=KV_DIM,
+ dropout=DROPOUT,
+ layer_norm_eps=1e-6,
+ initializer_range=0.02,
+ logits_processor=partial(
+ CorrectItemsLogitsProcessor,
+ NUM_CODEBOOKS,
+ CODEBOOK_SIZE,
+ mappings,
+ NUM_BEAMS
+ ),
+ use_microbatching=USE_MICROBATCHING,
+ microbatch_size=MICROBATCH_SIZE
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=LR,
+ )
+
+ EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS)
+
+ callbacks = [
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'].item(),
+ }, name='train'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'train/loss': cb.MeanAccumulator(),
+ },
+ reset_every_num_steps=EPOCH_NUM_STEPS
+ ),
+
+ cb.Validation(
+ dataset=valid_dataloder,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, _:{
+ 'loss': model_outputs['loss'].item(),
+ 'recall@5': model_outputs['recall@5'].tolist(),
+ 'recall@10': model_outputs['recall@10'].tolist(),
+ 'recall@20': model_outputs['recall@20'].tolist(),
+ 'ndcg@5': model_outputs['ndcg@5'].tolist(),
+ 'ndcg@10': model_outputs['ndcg@10'].tolist(),
+ 'ndcg@20': model_outputs['ndcg@20'].tolist(),
+ }, name='validation'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'validation/loss': cb.MeanAccumulator(),
+ 'validation/recall@5': cb.MeanAccumulator(),
+ 'validation/recall@10': cb.MeanAccumulator(),
+ 'validation/recall@20': cb.MeanAccumulator(),
+ 'validation/ndcg@5': cb.MeanAccumulator(),
+ 'validation/ndcg@10': cb.MeanAccumulator(),
+ 'validation/ndcg@20': cb.MeanAccumulator(),
+ },
+ ),
+ ],
+ ).every_num_steps(EPOCH_NUM_STEPS),
+
+ cb.Validation(
+ dataset=eval_dataloder,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, _: {
+ 'loss': model_outputs['loss'].item(),
+ 'recall@5': model_outputs['recall@5'].tolist(),
+ 'recall@10': model_outputs['recall@10'].tolist(),
+ 'recall@20': model_outputs['recall@20'].tolist(),
+ 'ndcg@5': model_outputs['ndcg@5'].tolist(),
+ 'ndcg@10': model_outputs['ndcg@10'].tolist(),
+ 'ndcg@20': model_outputs['ndcg@20'].tolist(),
+ }, name='eval'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'eval/loss': cb.MeanAccumulator(),
+ 'eval/recall@5': cb.MeanAccumulator(),
+ 'eval/recall@10': cb.MeanAccumulator(),
+ 'eval/recall@20': cb.MeanAccumulator(),
+ 'eval/ndcg@5': cb.MeanAccumulator(),
+ 'eval/ndcg@10': cb.MeanAccumulator(),
+ 'eval/ndcg@20': cb.MeanAccumulator(),
+ },
+ ),
+ ],
+ ).every_num_steps(EPOCH_NUM_STEPS * 4),
+
+ cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),
+
+ cb.EarlyStopping(
+ metric='validation/ndcg@20',
+ patience=40 * 4,
+ minimize=False,
+ model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
+ ).every_num_steps(EPOCH_NUM_STEPS)
+
+ # cb.Profiler(
+ # wait=10,
+ # warmup=10,
+ # active=10,
+ # logdir=TENSORBOARD_LOGDIR
+ # ),
+ # cb.StopAfterNumSteps(40)
+
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = TrainingRunner(
+ model=model,
+ optimizer=optimizer,
+ dataset=train_dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger-lsvd/lsvd_varka_4.1_plum.py b/scripts/tiger-lsvd/lsvd_varka_4.1_plum.py
new file mode 100644
index 0000000..cc35507
--- /dev/null
+++ b/scripts/tiger-lsvd/lsvd_varka_4.1_plum.py
@@ -0,0 +1,304 @@
+from collections import defaultdict
+import json
+import murmurhash
+import numpy as np
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.transforms import Collate, Transform
+from irec.data.dataloader import DataLoader
+
+from data import Dataset
+
+print("tiger no arrow varka 4.1")
+
+# ПУТИ
+
+IREC_PATH = '../../'
+INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/base_with_gap_interactions_grouped.parquet"
+INTERACTIONS_VALID_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/val_interactions_grouped.parquet"
+INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/test_interactions_grouped.parquet"
+
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/4-1_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e_35_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.1/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.1/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.1/eval_batches/')
+
+
+# ОСТАЛЬНОЕ
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
+MAX_TRAIN_EVENTS = 500
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+NUM_USER_HASH = 2000
+CODEBOOK_SIZE = 512
+NUM_CODEBOOKS = 4
+
+UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities
+PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1,
+EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2,
+DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3,
+
+
+class TigerProcessing(Transform):
+ def __call__(self, batch):
+ input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask']
+ batch_size = attention_mask.shape[0]
+
+ input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ???
+
+ input_semantic_ids = np.concatenate([
+ input_semantic_ids,
+ NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
+ ], axis=-1)
+
+ attention_mask = np.concatenate([
+ attention_mask,
+ np.ones((batch_size, 1), dtype=attention_mask.dtype)
+ ], axis=-1)
+
+ batch['input.data'] = input_semantic_ids
+ batch['input.mask'] = attention_mask
+
+ target_semantic_ids = batch['labels.semantic.padded']
+ target_semantic_ids = np.concatenate([
+ np.ones(
+ (batch_size, 1),
+ dtype=np.int64,
+ ) * DECODER_START_TOKEN_ID,
+ target_semantic_ids
+ ], axis=-1)
+
+ batch['output.data'] = target_semantic_ids
+
+ return batch
+
+
+class ToMasked(Transform):
+ def __init__(self, prefix, is_right_aligned=False):
+ self._prefix = prefix
+ self._is_right_aligned = is_right_aligned
+
+ def __call__(self, batch):
+ data = batch[f'{self._prefix}.ids']
+ lengths = batch[f'{self._prefix}.length']
+
+ batch_size = lengths.shape[0]
+ max_sequence_length = int(lengths.max())
+
+ if len(data.shape) == 1: # only indices
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len)
+ else:
+ assert len(data.shape) == 2 # embeddings
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length, data.shape[-1]),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len, emb_dim)
+
+ mask = np.arange(max_sequence_length)[None] < lengths[:, None]
+
+ if self._is_right_aligned:
+ mask = np.flip(mask, axis=-1)
+
+ padded_tensor[mask] = data
+
+ batch[f'{self._prefix}.padded'] = padded_tensor
+ batch[f'{self._prefix}.mask'] = mask
+
+ return batch
+
+
+class SemanticIdsMapper(Transform):
+ def __init__(self, mapping, names=[]):
+ super().__init__()
+ self._mapping = mapping
+ self._names = names
+
+ max_item_id = max(int(k) for k in mapping.keys())
+ print(len(list(mapping.keys())), min(int(k) for k in mapping.keys()) , max(int(k) for k in mapping.keys()))
+ # print(mapping["280052"]) #304781
+ # assert False
+ data = []
+ for i in range(max_item_id + 1):
+ if str(i) in mapping:
+ data.append(mapping[str(i)])
+ else:
+ data.append([-1] * NUM_CODEBOOKS)
+
+ self._mapping_tensor = torch.tensor(data, dtype=torch.long)
+ self._semantic_length = self._mapping_tensor.shape[-1]
+
+ missing_count = (max_item_id + 1) - len(mapping)
+ print(f"Mapping: {len(mapping)} items, {missing_count} missing (-1 filled)")
+
+ def __call__(self, batch):
+ for name in self._names:
+ if f'{name}.ids' in batch:
+ ids = batch[f'{name}.ids']
+ lengths = batch[f'{name}.length']
+ assert ids.min() >= 0
+ assert ids.max() < self._mapping_tensor.shape[0]
+ semantic_ids = self._mapping_tensor[ids].flatten()
+
+ assert (semantic_ids != -1).all(), \
+ f"Missing mappings detected in {name}! Invalid positions: {(semantic_ids == -1).sum()} out of {len(semantic_ids)}"
+
+ batch[f'{name}.semantic.ids'] = semantic_ids.numpy()
+ batch[f'{name}.semantic.length'] = lengths * self._semantic_length
+
+ return batch
+
+
+class UserHashing(Transform):
+ def __init__(self, hash_size):
+ super().__init__()
+ self._hash_size = hash_size
+
+ def __call__(self, batch):
+ batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64)
+ return batch
+
+
+def save_batches_to_arrow(batches, output_dir):
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ for batch_idx, batch in enumerate(batches):
+ length_groups = defaultdict(dict)
+ metadata_groups = defaultdict(dict)
+
+ for key, value in batch.items():
+ length = len(value)
+
+ metadata_groups[length][f'{key}_shape'] = str(value.shape)
+ metadata_groups[length][f'{key}_dtype'] = str(value.dtype)
+
+ if value.ndim == 1:
+ # 1D массив - сохраняем как есть
+ length_groups[length][key] = value
+ elif value.ndim == 2:
+ # 2D массив - используем list of lists
+ length_groups[length][key] = value.tolist()
+ else:
+ # >2D массив - flatten и сохраняем shape
+ length_groups[length][key] = value.flatten()
+
+ for length, fields in length_groups.items():
+ arrow_dict = {}
+ for k, v in fields.items():
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list):
+ # List of lists (2D)
+ arrow_dict[k] = pa.array(v)
+ else:
+ arrow_dict[k] = pa.array(v)
+
+ table = pa.table(arrow_dict)
+ if length in metadata_groups:
+ table = table.replace_schema_metadata(metadata_groups[length])
+
+ feather.write_feather(
+ table,
+ output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+ compression='lz4'
+ )
+
+ # arrow_dict = {k: pa.array(v) for k, v in fields.items()}
+ # table = pa.table(arrow_dict)
+
+ # feather.write_feather(
+ # table,
+ # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+ # compression='lz4'
+ # )
+
+
+def main():
+ with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+ mappings = json.load(f)
+ print("варка может начать умирать")
+ data = Dataset.create_timestamp_based_parquet(
+ train_parquet_path=INTERACTIONS_TRAIN_PATH,
+ validation_parquet_path=INTERACTIONS_VALID_PATH,
+ test_parquet_path=INTERACTIONS_TEST_PATH,
+ max_sequence_length=MAX_SEQ_LEN,
+ sampler_type='tiger',
+ min_sample_len=2,
+ is_extended=True,
+ max_train_events=MAX_TRAIN_EVENTS
+ )
+
+ train_dataset, valid_dataset, eval_dataset = data.get_datasets()
+ print("варка не умерла")
+ train_dataloader = DataLoader(
+ dataset=train_dataset,
+ batch_size=TRAIN_BATCH_SIZE,
+ shuffle=True,
+ drop_last=True
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ valid_dataloader = DataLoader(
+ dataset=valid_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ eval_dataloader = DataLoader(
+ dataset=eval_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ train_batches = []
+ for train_batch in train_dataloader:
+ train_batches.append(train_batch)
+ save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR)
+
+ valid_batches = []
+ for valid_batch in valid_dataloader:
+ valid_batches.append(valid_batch)
+ save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR)
+
+ eval_batches = []
+ for eval_batch in eval_dataloader:
+ eval_batches.append(eval_batch)
+ save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR)
+
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger-lsvd/lsvd_varka_4.2_plum.py b/scripts/tiger-lsvd/lsvd_varka_4.2_plum.py
new file mode 100644
index 0000000..7de54e4
--- /dev/null
+++ b/scripts/tiger-lsvd/lsvd_varka_4.2_plum.py
@@ -0,0 +1,304 @@
+from collections import defaultdict
+import json
+import murmurhash
+import numpy as np
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.transforms import Collate, Transform
+from irec.data.dataloader import DataLoader
+
+from data import Dataset
+
+print("tiger no arrow varka 4.1")
+
+# ПУТИ
+
+IREC_PATH = '../../'
+INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/base_with_gap_interactions_grouped.parquet"
+INTERACTIONS_VALID_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/val_interactions_grouped.parquet"
+INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-days-base-ows/test_interactions_grouped.parquet"
+
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/4-2_vk_lsvd_ods_base_with_gap_cb_512_ws_2_k_2000_8w_e_35_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.2/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.2/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-4.2/eval_batches/')
+
+
+# ОСТАЛЬНОЕ
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
+MAX_TRAIN_EVENTS = 500
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+NUM_USER_HASH = 2000
+CODEBOOK_SIZE = 512
+NUM_CODEBOOKS = 4
+
+UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities
+PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1,
+EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2,
+DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3,
+
+
+class TigerProcessing(Transform):
+ def __call__(self, batch):
+ input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask']
+ batch_size = attention_mask.shape[0]
+
+ input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ???
+
+ input_semantic_ids = np.concatenate([
+ input_semantic_ids,
+ NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
+ ], axis=-1)
+
+ attention_mask = np.concatenate([
+ attention_mask,
+ np.ones((batch_size, 1), dtype=attention_mask.dtype)
+ ], axis=-1)
+
+ batch['input.data'] = input_semantic_ids
+ batch['input.mask'] = attention_mask
+
+ target_semantic_ids = batch['labels.semantic.padded']
+ target_semantic_ids = np.concatenate([
+ np.ones(
+ (batch_size, 1),
+ dtype=np.int64,
+ ) * DECODER_START_TOKEN_ID,
+ target_semantic_ids
+ ], axis=-1)
+
+ batch['output.data'] = target_semantic_ids
+
+ return batch
+
+
+class ToMasked(Transform):
+ def __init__(self, prefix, is_right_aligned=False):
+ self._prefix = prefix
+ self._is_right_aligned = is_right_aligned
+
+ def __call__(self, batch):
+ data = batch[f'{self._prefix}.ids']
+ lengths = batch[f'{self._prefix}.length']
+
+ batch_size = lengths.shape[0]
+ max_sequence_length = int(lengths.max())
+
+ if len(data.shape) == 1: # only indices
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len)
+ else:
+ assert len(data.shape) == 2 # embeddings
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length, data.shape[-1]),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len, emb_dim)
+
+ mask = np.arange(max_sequence_length)[None] < lengths[:, None]
+
+ if self._is_right_aligned:
+ mask = np.flip(mask, axis=-1)
+
+ padded_tensor[mask] = data
+
+ batch[f'{self._prefix}.padded'] = padded_tensor
+ batch[f'{self._prefix}.mask'] = mask
+
+ return batch
+
+
+class SemanticIdsMapper(Transform):
+ def __init__(self, mapping, names=[]):
+ super().__init__()
+ self._mapping = mapping
+ self._names = names
+
+ max_item_id = max(int(k) for k in mapping.keys())
+ print(len(list(mapping.keys())), min(int(k) for k in mapping.keys()) , max(int(k) for k in mapping.keys()))
+ # print(mapping["280052"]) #304781
+ # assert False
+ data = []
+ for i in range(max_item_id + 1):
+ if str(i) in mapping:
+ data.append(mapping[str(i)])
+ else:
+ data.append([-1] * NUM_CODEBOOKS)
+
+ self._mapping_tensor = torch.tensor(data, dtype=torch.long)
+ self._semantic_length = self._mapping_tensor.shape[-1]
+
+ missing_count = (max_item_id + 1) - len(mapping)
+ print(f"Mapping: {len(mapping)} items, {missing_count} missing (-1 filled)")
+
+ def __call__(self, batch):
+ for name in self._names:
+ if f'{name}.ids' in batch:
+ ids = batch[f'{name}.ids']
+ lengths = batch[f'{name}.length']
+ assert ids.min() >= 0
+ assert ids.max() < self._mapping_tensor.shape[0]
+ semantic_ids = self._mapping_tensor[ids].flatten()
+
+ assert (semantic_ids != -1).all(), \
+ f"Missing mappings detected in {name}! Invalid positions: {(semantic_ids == -1).sum()} out of {len(semantic_ids)}"
+
+ batch[f'{name}.semantic.ids'] = semantic_ids.numpy()
+ batch[f'{name}.semantic.length'] = lengths * self._semantic_length
+
+ return batch
+
+
+class UserHashing(Transform):
+ def __init__(self, hash_size):
+ super().__init__()
+ self._hash_size = hash_size
+
+ def __call__(self, batch):
+ batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64)
+ return batch
+
+
+def save_batches_to_arrow(batches, output_dir):
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ for batch_idx, batch in enumerate(batches):
+ length_groups = defaultdict(dict)
+ metadata_groups = defaultdict(dict)
+
+ for key, value in batch.items():
+ length = len(value)
+
+ metadata_groups[length][f'{key}_shape'] = str(value.shape)
+ metadata_groups[length][f'{key}_dtype'] = str(value.dtype)
+
+ if value.ndim == 1:
+ # 1D массив - сохраняем как есть
+ length_groups[length][key] = value
+ elif value.ndim == 2:
+ # 2D массив - используем list of lists
+ length_groups[length][key] = value.tolist()
+ else:
+ # >2D массив - flatten и сохраняем shape
+ length_groups[length][key] = value.flatten()
+
+ for length, fields in length_groups.items():
+ arrow_dict = {}
+ for k, v in fields.items():
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list):
+ # List of lists (2D)
+ arrow_dict[k] = pa.array(v)
+ else:
+ arrow_dict[k] = pa.array(v)
+
+ table = pa.table(arrow_dict)
+ if length in metadata_groups:
+ table = table.replace_schema_metadata(metadata_groups[length])
+
+ feather.write_feather(
+ table,
+ output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+ compression='lz4'
+ )
+
+ # arrow_dict = {k: pa.array(v) for k, v in fields.items()}
+ # table = pa.table(arrow_dict)
+
+ # feather.write_feather(
+ # table,
+ # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+ # compression='lz4'
+ # )
+
+
+def main():
+ with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+ mappings = json.load(f)
+ print("варка может начать умирать")
+ data = Dataset.create_timestamp_based_parquet(
+ train_parquet_path=INTERACTIONS_TRAIN_PATH,
+ validation_parquet_path=INTERACTIONS_VALID_PATH,
+ test_parquet_path=INTERACTIONS_TEST_PATH,
+ max_sequence_length=MAX_SEQ_LEN,
+ sampler_type='tiger',
+ min_sample_len=2,
+ is_extended=True,
+ max_train_events=MAX_TRAIN_EVENTS
+ )
+
+ train_dataset, valid_dataset, eval_dataset = data.get_datasets()
+ print("варка не умерла")
+ train_dataloader = DataLoader(
+ dataset=train_dataset,
+ batch_size=TRAIN_BATCH_SIZE,
+ shuffle=True,
+ drop_last=True
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ valid_dataloader = DataLoader(
+ dataset=valid_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ eval_dataloader = DataLoader(
+ dataset=eval_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ train_batches = []
+ for train_batch in train_dataloader:
+ train_batches.append(train_batch)
+ save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR)
+
+ valid_batches = []
+ for valid_batch in valid_dataloader:
+ valid_batches.append(valid_batch)
+ save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR)
+
+ eval_batches = []
+ for eval_batch in eval_dataloader:
+ eval_batches.append(eval_batch)
+ save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR)
+
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger-lsvd/lsvd_varka_rqvae.py b/scripts/tiger-lsvd/lsvd_varka_rqvae.py
new file mode 100644
index 0000000..bb5ecc0
--- /dev/null
+++ b/scripts/tiger-lsvd/lsvd_varka_rqvae.py
@@ -0,0 +1,304 @@
+from collections import defaultdict
+import json
+import murmurhash
+import numpy as np
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.transforms import Collate, Transform
+from irec.data.dataloader import DataLoader
+
+from data import Dataset
+
+print("tiger no arrow varka 4.1")
+
+# ПУТИ
+
+IREC_PATH = '../../'
+INTERACTIONS_TRAIN_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/base_with_gap_interactions_grouped.parquet"
+INTERACTIONS_VALID_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/val_interactions_grouped.parquet"
+INTERACTIONS_TEST_PATH = "/home/jovyan/IRec/sigir/lsvd_data/8-weeks-base-ows/test_interactions_grouped.parquet"
+
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/rqvae_vk_lsvd_cz_512_8-weeks_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-rqvae/train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-rqvae/valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/lsvd/8-weeks-base-one-week-split-rqvae/eval_batches/')
+
+
+# ОСТАЛЬНОЕ
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
+MAX_TRAIN_EVENTS = 500
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+NUM_USER_HASH = 2000
+CODEBOOK_SIZE = 512
+NUM_CODEBOOKS = 4
+
+UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities
+PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1,
+EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2,
+DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3,
+
+
+class TigerProcessing(Transform):
+ def __call__(self, batch):
+ input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask']
+ batch_size = attention_mask.shape[0]
+
+ input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ???
+
+ input_semantic_ids = np.concatenate([
+ input_semantic_ids,
+ NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
+ ], axis=-1)
+
+ attention_mask = np.concatenate([
+ attention_mask,
+ np.ones((batch_size, 1), dtype=attention_mask.dtype)
+ ], axis=-1)
+
+ batch['input.data'] = input_semantic_ids
+ batch['input.mask'] = attention_mask
+
+ target_semantic_ids = batch['labels.semantic.padded']
+ target_semantic_ids = np.concatenate([
+ np.ones(
+ (batch_size, 1),
+ dtype=np.int64,
+ ) * DECODER_START_TOKEN_ID,
+ target_semantic_ids
+ ], axis=-1)
+
+ batch['output.data'] = target_semantic_ids
+
+ return batch
+
+
+class ToMasked(Transform):
+ def __init__(self, prefix, is_right_aligned=False):
+ self._prefix = prefix
+ self._is_right_aligned = is_right_aligned
+
+ def __call__(self, batch):
+ data = batch[f'{self._prefix}.ids']
+ lengths = batch[f'{self._prefix}.length']
+
+ batch_size = lengths.shape[0]
+ max_sequence_length = int(lengths.max())
+
+ if len(data.shape) == 1: # only indices
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len)
+ else:
+ assert len(data.shape) == 2 # embeddings
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length, data.shape[-1]),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len, emb_dim)
+
+ mask = np.arange(max_sequence_length)[None] < lengths[:, None]
+
+ if self._is_right_aligned:
+ mask = np.flip(mask, axis=-1)
+
+ padded_tensor[mask] = data
+
+ batch[f'{self._prefix}.padded'] = padded_tensor
+ batch[f'{self._prefix}.mask'] = mask
+
+ return batch
+
+
+class SemanticIdsMapper(Transform):
+ def __init__(self, mapping, names=[]):
+ super().__init__()
+ self._mapping = mapping
+ self._names = names
+
+ max_item_id = max(int(k) for k in mapping.keys())
+ print(len(list(mapping.keys())), min(int(k) for k in mapping.keys()) , max(int(k) for k in mapping.keys()))
+ # print(mapping["280052"]) #304781
+ # assert False
+ data = []
+ for i in range(max_item_id + 1):
+ if str(i) in mapping:
+ data.append(mapping[str(i)])
+ else:
+ data.append([-1] * NUM_CODEBOOKS)
+
+ self._mapping_tensor = torch.tensor(data, dtype=torch.long)
+ self._semantic_length = self._mapping_tensor.shape[-1]
+
+ missing_count = (max_item_id + 1) - len(mapping)
+ print(f"Mapping: {len(mapping)} items, {missing_count} missing (-1 filled)")
+
+ def __call__(self, batch):
+ for name in self._names:
+ if f'{name}.ids' in batch:
+ ids = batch[f'{name}.ids']
+ lengths = batch[f'{name}.length']
+ assert ids.min() >= 0
+ assert ids.max() < self._mapping_tensor.shape[0]
+ semantic_ids = self._mapping_tensor[ids].flatten()
+
+ assert (semantic_ids != -1).all(), \
+ f"Missing mappings detected in {name}! Invalid positions: {(semantic_ids == -1).sum()} out of {len(semantic_ids)}"
+
+ batch[f'{name}.semantic.ids'] = semantic_ids.numpy()
+ batch[f'{name}.semantic.length'] = lengths * self._semantic_length
+
+ return batch
+
+
+class UserHashing(Transform):
+ def __init__(self, hash_size):
+ super().__init__()
+ self._hash_size = hash_size
+
+ def __call__(self, batch):
+ batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64)
+ return batch
+
+
+def save_batches_to_arrow(batches, output_dir):
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ for batch_idx, batch in enumerate(batches):
+ length_groups = defaultdict(dict)
+ metadata_groups = defaultdict(dict)
+
+ for key, value in batch.items():
+ length = len(value)
+
+ metadata_groups[length][f'{key}_shape'] = str(value.shape)
+ metadata_groups[length][f'{key}_dtype'] = str(value.dtype)
+
+ if value.ndim == 1:
+ # 1D массив - сохраняем как есть
+ length_groups[length][key] = value
+ elif value.ndim == 2:
+ # 2D массив - используем list of lists
+ length_groups[length][key] = value.tolist()
+ else:
+ # >2D массив - flatten и сохраняем shape
+ length_groups[length][key] = value.flatten()
+
+ for length, fields in length_groups.items():
+ arrow_dict = {}
+ for k, v in fields.items():
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list):
+ # List of lists (2D)
+ arrow_dict[k] = pa.array(v)
+ else:
+ arrow_dict[k] = pa.array(v)
+
+ table = pa.table(arrow_dict)
+ if length in metadata_groups:
+ table = table.replace_schema_metadata(metadata_groups[length])
+
+ feather.write_feather(
+ table,
+ output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+ compression='lz4'
+ )
+
+ # arrow_dict = {k: pa.array(v) for k, v in fields.items()}
+ # table = pa.table(arrow_dict)
+
+ # feather.write_feather(
+ # table,
+ # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+ # compression='lz4'
+ # )
+
+
+def main():
+ with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+ mappings = json.load(f)
+ print("варка может начать умирать")
+ data = Dataset.create_timestamp_based_parquet(
+ train_parquet_path=INTERACTIONS_TRAIN_PATH,
+ validation_parquet_path=INTERACTIONS_VALID_PATH,
+ test_parquet_path=INTERACTIONS_TEST_PATH,
+ max_sequence_length=MAX_SEQ_LEN,
+ sampler_type='tiger',
+ min_sample_len=2,
+ is_extended=True,
+ max_train_events=MAX_TRAIN_EVENTS
+ )
+
+ train_dataset, valid_dataset, eval_dataset = data.get_datasets()
+ print("варка не умерла")
+ train_dataloader = DataLoader(
+ dataset=train_dataset,
+ batch_size=TRAIN_BATCH_SIZE,
+ shuffle=True,
+ drop_last=True
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ valid_dataloader = DataLoader(
+ dataset=valid_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ eval_dataloader = DataLoader(
+ dataset=eval_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ train_batches = []
+ for train_batch in train_dataloader:
+ train_batches.append(train_batch)
+ save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR)
+
+ valid_batches = []
+ for valid_batch in valid_dataloader:
+ valid_batches.append(valid_batch)
+ save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR)
+
+ eval_batches = []
+ for eval_batch in eval_dataloader:
+ eval_batches.append(eval_batch)
+ save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR)
+
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger-lsvd/models.py b/scripts/tiger-lsvd/models.py
new file mode 100644
index 0000000..f89419e
--- /dev/null
+++ b/scripts/tiger-lsvd/models.py
@@ -0,0 +1,223 @@
+import torch
+from transformers import T5ForConditionalGeneration, T5Config, LogitsProcessor
+
+from irec.models import TorchModel
+
+
+class CorrectItemsLogitsProcessor(LogitsProcessor):
+ def __init__(self, num_codebooks, codebook_size, mapping, num_beams, visited_items):
+ self.num_codebooks = num_codebooks
+ self.codebook_size = codebook_size
+ self.num_beams = num_beams
+
+ semantic_ids = []
+ for i in range(len(mapping)):
+ assert len(mapping[str(i)]) == num_codebooks, 'All semantic ids must have the same length'
+ semantic_ids.append(mapping[str(i)])
+
+ self.index_semantic_ids = torch.tensor(semantic_ids, dtype=torch.long, device=visited_items.device) # (num_items, semantic_ids)
+
+ batch_size, _ = visited_items.shape
+
+ self.index_semantic_ids = torch.tile(self.index_semantic_ids[None], dims=[batch_size, 1, 1]) # (batch_size, num_items, semantic_ids)
+
+ index = visited_items[..., None].tile(dims=[1, 1, num_codebooks]) # (batch_size, num_rated, semantic_ids)
+ self.index_semantic_ids = torch.scatter(
+ input=self.index_semantic_ids,
+ dim=1,
+ index=index,
+ src=torch.zeros_like(index)
+ ) # (batch_size, num_items, semantic_ids)
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ next_sid_codebook_num = (torch.minimum((input_ids[:, -1].max() // self.codebook_size), torch.as_tensor(self.num_codebooks - 1)).item() + 1) % self.num_codebooks
+ a = torch.tile(self.index_semantic_ids[:, None, :, next_sid_codebook_num], dims=[1, self.num_beams, 1]) # (batch_size, num_beams, num_items)
+ a = a.reshape(a.shape[0] * a.shape[1], a.shape[2]) # (batch_size * num_beams, num_items)
+
+ if next_sid_codebook_num != 0:
+ b = torch.tile(self.index_semantic_ids[:, None :, :next_sid_codebook_num], dims=[1, self.num_beams, 1, 1]) # (batch_size, num_beams, num_items, sid_len)
+ b = b.reshape(b.shape[0] * b.shape[1], b.shape[2], b.shape[3]) # (batch_size * num_beams, num_items, sid_len)
+
+ current_prefixes = input_ids[:, -next_sid_codebook_num:] # (batch_size * num_beams, sid_len)
+ possible_next_items_mask = (
+ torch.eq(current_prefixes[:, None, :], b).long().sum(dim=-1) == next_sid_codebook_num
+ ) # (batch_size * num_beams, num_items)
+ a[~possible_next_items_mask] = (next_sid_codebook_num + 1) * self.codebook_size
+
+ scores_mask = torch.zeros_like(scores).bool() # (batch_size * num_beams, num_items)
+ scores_mask = torch.scatter_add(
+ input=scores_mask,
+ dim=-1,
+ index=a,
+ src=torch.ones_like(a).bool()
+ )
+
+ scores[:, :next_sid_codebook_num * self.codebook_size] = -torch.inf
+ scores[:, (next_sid_codebook_num + 1) * self.codebook_size:] = -torch.inf
+ scores[~(scores_mask.bool())] = -torch.inf
+
+ return scores
+
+
+class TigerModel(TorchModel):
+ def __init__(
+ self,
+ embedding_dim,
+ codebook_size,
+ sem_id_len,
+ num_positions,
+ user_ids_count,
+ num_heads,
+ num_encoder_layers,
+ num_decoder_layers,
+ dim_feedforward,
+ num_beams=100,
+ num_return_sequences=20,
+ d_kv=64,
+ layer_norm_eps=1e-6,
+ activation='relu',
+ dropout=0.1,
+ initializer_range=0.02,
+ logits_processor=None,
+ use_microbatching=False,
+ microbatch_size=128
+ ):
+ super().__init__()
+ self._embedding_dim = embedding_dim
+ self._codebook_size = codebook_size
+ self._num_positions = num_positions
+ self._num_heads = num_heads
+ self._num_encoder_layers = num_encoder_layers
+ self._num_decoder_layers = num_decoder_layers
+ self._dim_feedforward = dim_feedforward
+ self._num_beams = num_beams
+ self._num_return_sequences = num_return_sequences
+ self._d_kv = d_kv
+ self._layer_norm_eps = layer_norm_eps
+ self._activation = activation
+ self._dropout = dropout
+ self._sem_id_len = sem_id_len
+ self.user_ids_count = user_ids_count
+ self.logits_processor = logits_processor
+ self._use_microbatching = use_microbatching
+ self._microbatch_size = microbatch_size
+
+ unified_vocab_size = codebook_size * self._sem_id_len + self.user_ids_count + 10 # 10 for utilities
+ self.config = T5Config(
+ vocab_size=unified_vocab_size,
+ d_model=self._embedding_dim,
+ d_kv=self._d_kv,
+ d_ff=self._dim_feedforward,
+ num_layers=self._num_encoder_layers,
+ num_decoder_layers=self._num_decoder_layers,
+ num_heads=self._num_heads,
+ dropout_rate=self._dropout,
+ is_encoder_decoder=True,
+ use_cache=False,
+ pad_token_id=unified_vocab_size - 1,
+ eos_token_id=unified_vocab_size - 2,
+ decoder_start_token_id=unified_vocab_size - 3,
+ layer_norm_epsilon=self._layer_norm_eps,
+ feed_forward_proj=self._activation,
+ tie_word_embeddings=False
+ )
+ self.model = T5ForConditionalGeneration(config=self.config)
+ self._init_weights(initializer_range)
+
+ self.model = torch.compile(
+ self.model,
+ mode='reduce-overhead',
+ fullgraph=False,
+ dynamic=True
+ )
+
+ def forward(self, inputs):
+ input_semantic_ids = inputs['input.data']
+ attention_mask = inputs['input.mask']
+ target_semantic_ids = inputs['output.data']
+
+ decoder_input_ids = target_semantic_ids[:, :-1].contiguous()
+ labels = target_semantic_ids[:, 1:].contiguous()
+
+ model_output = self.model(
+ input_ids=input_semantic_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ labels=labels
+ )
+ loss = model_output['loss']
+
+ metrics = {'loss': loss.detach()}
+
+ if not self.training and not self._use_microbatching:
+ visited_batch = inputs['visited.padded']
+
+ output = self.model.generate(
+ input_ids=input_semantic_ids,
+ attention_mask=attention_mask,
+ num_beams=self._num_beams,
+ num_return_sequences=self._num_return_sequences,
+ max_length=self._sem_id_len + 1,
+ decoder_start_token_id=self.config.decoder_start_token_id,
+ eos_token_id=self.config.eos_token_id,
+ pad_token_id=self.config.pad_token_id,
+ do_sample=False,
+ early_stopping=False,
+ logits_processor=[self.logits_processor(visited_items=visited_batch)] if self.logits_processor is not None else [],
+ )
+
+ predictions = output[:, 1:].reshape(-1, self._num_return_sequences, self._sem_id_len)
+
+ all_hits = (torch.eq(predictions, labels[:, None]).sum(dim=-1)) # (batch_size, top_k)
+ elif not self.training and self._use_microbatching:
+ visited_batch = inputs['visited.padded']
+ batch_size = input_semantic_ids.shape[0]
+
+ inference_batch_size = self._microbatch_size # вместо полного batch_size
+
+ all_predictions = []
+ all_labels = []
+ # print(f"start to infer batch of shape {input_semantic_ids.shape} with new batch {inference_batch_size}")
+ for batch_idx in range(0, batch_size, inference_batch_size):
+ batch_end = min(batch_idx + inference_batch_size, batch_size)
+ batch_slice = slice(batch_idx, batch_end)
+
+ input_ids_batch = input_semantic_ids[batch_slice]
+ attention_mask_batch = attention_mask[batch_slice]
+ visited_batch_subset = visited_batch[batch_slice]
+ labels_batch = labels[batch_slice]
+
+ with torch.inference_mode():
+ output = self.model.generate(
+ input_ids=input_ids_batch,
+ attention_mask=attention_mask_batch,
+ num_beams=self._num_beams,
+ num_return_sequences=self._num_return_sequences,
+ max_length=self._sem_id_len + 1,
+ decoder_start_token_id=self.config.decoder_start_token_id,
+ eos_token_id=self.config.eos_token_id,
+ pad_token_id=self.config.pad_token_id,
+ do_sample=False,
+ early_stopping=False,
+ logits_processor=[self.logits_processor(visited_items=visited_batch_subset)] if self.logits_processor is not None else [],
+ )
+
+ predictions_batch = output[:, 1:].reshape(-1, self._num_return_sequences, self._sem_id_len)
+ all_predictions.append(predictions_batch)
+ all_labels.append(labels_batch)
+ # print("end infer of batch")
+
+ predictions = torch.cat(all_predictions, dim=0) # (batch_size, num_return_sequences, sem_id_len)
+ labels_full = torch.cat(all_labels, dim=0) # (batch_size, sem_id_len)
+ all_hits = (torch.eq(predictions, labels_full[:, None]).sum(dim=-1)) # (batch_size, top_k)
+
+ if not self.training:
+ for k in [5, 10, 20]:
+ hits = (all_hits[:, :k] == self._sem_id_len).float() # (batch_size, k)
+ recall = hits.sum(dim=-1) # (batch_size)
+ discount_factor = 1 / torch.log2(torch.arange(1, k + 1, 1).float() + 1.).to(hits.device) # (k)
+
+ metrics[f'recall@{k}'] = recall.cpu().float()
+ metrics[f'ndcg@{k}'] = torch.einsum('bk,k->b', hits, discount_factor).cpu().float()
+
+ return loss, metrics
\ No newline at end of file
diff --git a/scripts/tiger-yambda/data.py b/scripts/tiger-yambda/data.py
new file mode 100644
index 0000000..87ff07d
--- /dev/null
+++ b/scripts/tiger-yambda/data.py
@@ -0,0 +1,498 @@
+from collections import defaultdict
+import json
+from loguru import logger
+import numpy as np
+from pathlib import Path
+
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.base import BaseDataset
+
+
+class Dataset:
+ def __init__(
+ self,
+ train_sampler,
+ validation_sampler,
+ test_sampler,
+ num_items,
+ max_sequence_length
+ ):
+ self._train_sampler = train_sampler
+ self._validation_sampler = validation_sampler
+ self._test_sampler = test_sampler
+ self._num_items = num_items
+ self._max_sequence_length = max_sequence_length
+
+ @classmethod
+ def create_timestamp_based(
+ cls,
+ train_json_path,
+ validation_json_path,
+ test_json_path,
+ max_sequence_length,
+ sampler_type,
+ min_sample_len=2,
+ is_extended=False,
+ max_train_events=50
+ ):
+ max_item_id = 0
+ train_dataset, validation_dataset, test_dataset = [], [], []
+ print("started to load datasets")
+ with open(train_json_path, 'r') as f:
+ train_data = json.load(f)
+ with open(validation_json_path, 'r') as f:
+ validation_data = json.load(f)
+ with open(test_json_path, 'r') as f:
+ test_data = json.load(f)
+
+ all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys())
+ print(f"all users count: {len(all_users)}")
+ us_count = 0
+ for user_id_str in all_users:
+ if us_count % 100 == 0:
+ print(f"user id {us_count}/{len(all_users)}: {user_id_str}")
+ user_id = int(user_id_str)
+
+ train_items = train_data.get(user_id_str, [])
+ validation_items = validation_data.get(user_id_str, [])
+ test_items = test_data.get(user_id_str, [])
+
+ full_sequence = train_items + validation_items + test_items
+ if full_sequence:
+ max_item_id = max(max_item_id, max(full_sequence))
+
+ if us_count % 100 == 0:
+ print(f"full sequence len: {len(full_sequence)}")
+ us_count += 1
+ assert len(full_sequence) >= 2, f'Core-5 dataset is used, user {user_id} has only {len(full_sequence)} items'
+
+ # Обрезаем train на последние max_train_events событий
+ train_items = train_items[-max_train_events:] if len(train_items) > max_train_events else train_items
+
+ if is_extended:
+ # sample = [1, 2]
+ # sample = [1, 2, 3]
+ # sample = [1, 2, 3, 4]
+ # sample = [1, 2, 3, 4, 5]
+ # sample = [1, 2, 3, 4, 5, 6]
+ # sample = [1, 2, 3, 4, 5, 6, 7]
+ # sample = [1, 2, 3, 4, 5, 6, 7, 8]
+ for prefix_length in range(min_sample_len, len(train_items) + 1):
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': train_items[:prefix_length],
+ })
+ else:
+ # sample = [1, 2, 3, 4, 5, 6, 7, 8]
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': train_items,
+ })
+
+ # валидация
+
+ # разворачиваем каждый айтем из валидации в отдельный сэмпл
+ # Пример: Train=[1,2], Valid=[3,4]
+ # sample = [1, 2, 3]
+ # sample = [1, 2, 3, 4]
+
+ current_history = train_items.copy()
+ valid_small_history = 0
+ for item in validation_items:
+ # эвал датасет сам отрезает таргет потом
+ sample_sequence = current_history + [item]
+
+ if len(sample_sequence) >= min_sample_len:
+ validation_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': sample_sequence,
+ })
+ else:
+ valid_small_history += 1
+ current_history.append(item)
+
+ # разворачиваем каждый айтем из теста в отдельный сэмпл
+ # Пример: Train=[1,2], Valid=[3,4], Test=[5, 6]
+ # sample = [1, 2, 3, 4, 5]
+ # sample = [1, 2, 3, 4, 5, 6]
+ current_history = train_items + validation_items
+ test_small_history = 0
+ for item in test_items:
+ sample_sequence = current_history + [item]
+ if len(sample_sequence) >= min_sample_len:
+ test_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': sample_sequence,
+ })
+ else:
+ test_small_history += 1
+ current_history.append(item)
+
+ print(f"Train dataset size: {len(train_dataset)}")
+ print(f"Validation dataset size: {len(validation_dataset)} with skipped {valid_small_history}")
+ print(f"Test dataset size: {len(test_dataset)} with skipped {test_small_history}")
+
+ logger.debug(f'Train dataset size: {len(train_dataset)}')
+ logger.debug(f'Validation dataset size: {len(validation_dataset)}')
+ logger.debug(f'Test dataset size: {len(test_dataset)}')
+
+ train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length)
+ validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length)
+ test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length)
+
+ return cls(
+ train_sampler=train_sampler,
+ validation_sampler=validation_sampler,
+ test_sampler=test_sampler,
+ num_items=max_item_id + 1, # +1 added because our ids are 0-indexed
+ max_sequence_length=max_sequence_length
+ )
+
+ @classmethod
+ def create_timestamp_based_with_one_valid(
+ cls,
+ train_json_path,
+ validation_json_path,
+ test_json_path,
+ max_sequence_length,
+ sampler_type,
+ min_sample_len=2,
+ is_extended=False,
+ max_train_events=50,
+ max_valid_events=50
+ ):
+ max_item_id = 0
+ train_dataset, validation_dataset, test_dataset = [], [], []
+ print("started to load datasets")
+ with open(train_json_path, 'r') as f:
+ train_data = json.load(f)
+ with open(validation_json_path, 'r') as f:
+ validation_data = json.load(f)
+ with open(test_json_path, 'r') as f:
+ test_data = json.load(f)
+
+ all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys())
+ print(f"all users count: {len(all_users)}")
+ us_count = 0
+ for user_id_str in all_users:
+ if us_count % 100 == 0:
+ print(f"user id {us_count}/{len(all_users)}: {user_id_str}")
+ user_id = int(user_id_str)
+
+ train_items = train_data.get(user_id_str, [])
+ validation_items = validation_data.get(user_id_str, [])
+ test_items = test_data.get(user_id_str, [])
+
+ full_sequence = train_items + validation_items + test_items
+ if full_sequence:
+ max_item_id = max(max_item_id, max(full_sequence))
+
+ if us_count % 100 == 0:
+ print(f"full sequence len: {len(full_sequence)}")
+
+ assert len(full_sequence) >= 2, f'Core-5 dataset is used, user {user_id} has only {len(full_sequence)} items'
+
+ # Обрезаем train на последние max_train_events событий
+ train_items = train_items[-max_train_events:] if len(train_items) > max_train_events else train_items
+
+ if is_extended:
+ # sample = [1, 2]
+ # sample = [1, 2, 3]
+ # sample = [1, 2, 3, 4]
+ # sample = [1, 2, 3, 4, 5]
+ # sample = [1, 2, 3, 4, 5, 6]
+ # sample = [1, 2, 3, 4, 5, 6, 7]
+ # sample = [1, 2, 3, 4, 5, 6, 7, 8]
+ for prefix_length in range(min_sample_len, len(train_items) + 1):
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': train_items[:prefix_length],
+ })
+ else:
+ # sample = [1, 2, 3, 4, 5, 6, 7, 8]
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': train_items,
+ })
+
+ # валидация
+
+ # разворачиваем каждый айтем из валидации в отдельный сэмпл
+ # Пример: Train=[1,2], Valid=[3,4]
+ # sample = [1, 2, 3]
+ # sample = [1, 2, 3, 4]
+
+ current_history = train_items.copy()
+ if us_count % 100 == 0:
+ print(f"validation data length {len(validation_items[:max_valid_events])}")
+ us_count += 1
+ for item in validation_items[:max_valid_events]:
+ # эвал датасет сам отрезает таргет потом
+ sample_sequence = current_history + [item]
+
+ if len(sample_sequence) >= min_sample_len:
+ validation_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': sample_sequence,
+ })
+ current_history.append(item)
+
+ # разворачиваем каждый айтем из теста в отдельный сэмпл
+ # Пример: Train=[1,2], Valid=[3,4], Test=[5, 6]
+ # sample = [1, 2, 3, 4, 5]
+ # sample = [1, 2, 3, 4, 5, 6]
+ current_history = train_items + validation_items
+ for item in test_items:
+ sample_sequence = current_history + [item]
+ if len(sample_sequence) >= min_sample_len:
+ test_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': sample_sequence,
+ })
+ current_history.append(item)
+
+ print(f"Train dataset size: {len(train_dataset)}")
+ print(f"Validation dataset size: {len(validation_dataset)}")
+ print(f"Test dataset size: {len(test_dataset)}")
+
+ logger.debug(f'Train dataset size: {len(train_dataset)}')
+ logger.debug(f'Validation dataset size: {len(validation_dataset)}')
+ logger.debug(f'Test dataset size: {len(test_dataset)}')
+
+ train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length)
+ validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length)
+ test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length)
+
+ return cls(
+ train_sampler=train_sampler,
+ validation_sampler=validation_sampler,
+ test_sampler=test_sampler,
+ num_items=max_item_id + 1, # +1 added because our ids are 0-indexed
+ max_sequence_length=max_sequence_length
+ )
+
+ @classmethod
+ def create(cls, inter_json_path, max_sequence_length, sampler_type, is_extended=False):
+ max_item_id = 0
+ train_dataset, validation_dataset, test_dataset = [], [], []
+
+ with open(inter_json_path, 'r') as f:
+ user_interactions = json.load(f)
+
+ for user_id_str, item_ids in user_interactions.items():
+ user_id = int(user_id_str)
+
+ if item_ids:
+ max_item_id = max(max_item_id, max(item_ids))
+
+ assert len(item_ids) >= 5, f'Core-5 dataset is used, user {user_id} has only {len(item_ids)} items'
+
+ # sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] (leave one out scheme, 8 - train, 9 - valid, 10 - test)
+ if is_extended:
+ # sample = [1, 2]
+ # sample = [1, 2, 3]
+ # sample = [1, 2, 3, 4]
+ # sample = [1, 2, 3, 4, 5]
+ # sample = [1, 2, 3, 4, 5, 6]
+ # sample = [1, 2, 3, 4, 5, 6, 7]
+ # sample = [1, 2, 3, 4, 5, 6, 7, 8]
+ for prefix_length in range(2, len(item_ids) - 2 + 1):
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': item_ids[:prefix_length],
+ })
+ else:
+ # sample = [1, 2, 3, 4, 5, 6, 7, 8]
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': item_ids[:-2],
+ })
+
+ # sample = [1, 2, 3, 4, 5, 6, 7, 8, 9]
+ validation_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': item_ids[:-1],
+ })
+
+ # sample = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+ test_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': item_ids,
+ })
+
+ logger.debug(f'Train dataset size: {len(train_dataset)}')
+ logger.debug(f'Validation dataset size: {len(validation_dataset)}')
+ logger.debug(f'Test dataset size: {len(test_dataset)}')
+ logger.debug(f'Max item id: {max_item_id}')
+
+ train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length)
+ validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length)
+ test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length)
+
+ return cls(
+ train_sampler=train_sampler,
+ validation_sampler=validation_sampler,
+ test_sampler=test_sampler,
+ num_items=max_item_id + 1, # +1 added because our ids are 0-indexed
+ max_sequence_length=max_sequence_length
+ )
+
+ def get_datasets(self):
+ return self._train_sampler, self._validation_sampler, self._test_sampler
+
+ @property
+ def num_items(self):
+ return self._num_items
+
+ @property
+ def max_sequence_length(self):
+ return self._max_sequence_length
+
+
+class TrainDataset(BaseDataset):
+ def __init__(self, dataset, prediction_type, max_sequence_length):
+ self._dataset = dataset
+ self._prediction_type = prediction_type
+ self._max_sequence_length = max_sequence_length
+
+ self._transforms = {
+ 'sasrec': self._all_items_transform,
+ 'tiger': self._last_item_transform
+ }
+
+ def _all_items_transform(self, sample):
+ item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1]
+ next_item_sequence = sample['item.ids'][-self._max_sequence_length:][1:]
+ return {
+ 'user.ids': np.array(sample['user.ids'], dtype=np.int64),
+ 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64),
+ 'item.ids': np.array(item_sequence, dtype=np.int64),
+ 'item.length': np.array([len(item_sequence)], dtype=np.int64),
+ 'labels.ids': np.array(next_item_sequence, dtype=np.int64),
+ 'labels.length': np.array([len(next_item_sequence)], dtype=np.int64)
+ }
+
+ def _last_item_transform(self, sample):
+ item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1]
+ last_item = sample['item.ids'][-self._max_sequence_length:][-1]
+ return {
+ 'user.ids': np.array(sample['user.ids'], dtype=np.int64),
+ 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64),
+ 'item.ids': np.array(item_sequence, dtype=np.int64),
+ 'item.length': np.array([len(item_sequence)], dtype=np.int64),
+ 'labels.ids': np.array([last_item], dtype=np.int64),
+ 'labels.length': np.array([1], dtype=np.int64),
+ }
+
+ def __getitem__(self, index):
+ return self._transforms[self._prediction_type](self._dataset[index])
+
+ def __len__(self):
+ return len(self._dataset)
+
+
+class EvalDataset(BaseDataset):
+ def __init__(self, dataset, max_sequence_length):
+ self._dataset = dataset
+ self._max_sequence_length = max_sequence_length
+
+ @property
+ def dataset(self):
+ return self._dataset
+
+ def __len__(self):
+ return len(self._dataset)
+
+ def __getitem__(self, index):
+ sample = self._dataset[index]
+
+ item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1]
+ next_item = sample['item.ids'][-self._max_sequence_length:][-1]
+
+ return {
+ 'user.ids': np.array(sample['user.ids'], dtype=np.int64),
+ 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64),
+ 'item.ids': np.array(item_sequence, dtype=np.int64),
+ 'item.length': np.array([len(item_sequence)], dtype=np.int64),
+ 'labels.ids': np.array([next_item], dtype=np.int64),
+ 'labels.length': np.array([1], dtype=np.int64),
+ 'visited.ids': np.array(sample['item.ids'][:-1], dtype=np.int64),
+ 'visited.length': np.array([len(sample['item.ids'][:-1])], dtype=np.int64),
+ }
+
+
+ class ArrowBatchDataset(BaseDataset):
+     """Dataset whose items are whole batches stored as Arrow/Feather files.
+
+     Files are discovered with the glob ``batch_*_len_*.arrow`` inside ``batch_dir``; the
+     integer between the first and second underscore of the file stem is the batch id and
+     all files sharing an id are merged into one batch dict.
+
+     Args:
+         batch_dir: Directory containing the ``.arrow`` files.
+         device: Device every tensor is moved to when a batch is loaded.
+         preload: If True, load every batch in the constructor and serve from memory;
+             otherwise read the files on every ``__getitem__`` call.
+     """
+
+     def __init__(self, batch_dir, device='cuda', preload=False):
+         self.batch_dir = Path(batch_dir)
+         self.device = device
+
+         all_files = list(self.batch_dir.glob('batch_*_len_*.arrow'))
+
+         # Group the files by batch id.
+         batch_files_map = defaultdict(list)
+         for f in all_files:
+             batch_id = int(f.stem.split('_')[1])
+             batch_files_map[batch_id].append(f)
+
+         # Sort the paths of each batch so they are always read in the same order.
+         for batch_id in batch_files_map:
+             batch_files_map[batch_id].sort()
+
+         # Dataset index i refers to the i-th smallest batch id.
+         self.batch_indices = sorted(batch_files_map.keys())
+
+         if preload:
+             print(f"Preloading {len(self.batch_indices)} batches...")
+             self.cached_batches = []
+
+             for idx in range(len(self.batch_indices)):
+                 batch = self._load_batch(batch_files_map[self.batch_indices[idx]])
+                 self.cached_batches.append(batch)
+         else:
+             self.cached_batches = None
+             self.batch_files_map = batch_files_map
+
+     def _load_batch(self, arrow_files):
+         """Read the given files and return one dict mapping column name -> tensor on ``self.device``.
+
+         If several files contain a column with the same name, the one read last wins.
+         """
+         batch = {}
+
+         for arrow_file in arrow_files:
+             table = feather.read_table(arrow_file)
+             metadata = table.schema.metadata or {}
+
+             for col_name in table.column_names:
+                 col = table.column(col_name)
+
+                 # Optional schema metadata carrying the original array shape and dtype.
+                 shape_key = f'{col_name}_shape'
+                 dtype_key = f'{col_name}_dtype'
+
+                 if shape_key.encode() in metadata:
+                     # NOTE(review): eval() executes whatever string is stored in the file metadata.
+                     # If the writer always stores a plain tuple, ast.literal_eval is the safe
+                     # replacement - confirm the stored format before switching.
+                     shape = eval(metadata[shape_key.encode()].decode())
+                     dtype = np.dtype(metadata[dtype_key.encode()].decode())
+
+                     # Check the column type: list columns are rebuilt from Python lists,
+                     # flat columns are reshaped to the recorded shape.
+                     if pa.types.is_list(col.type) or pa.types.is_large_list(col.type):
+                         arr = np.array(col.to_pylist(), dtype=dtype)
+                     else:
+                         arr = col.to_numpy().reshape(shape).astype(dtype)
+                 else:
+                     # No metadata: keep whatever shape and dtype the conversion yields.
+                     if pa.types.is_list(col.type) or pa.types.is_large_list(col.type):
+                         arr = np.array(col.to_pylist())
+                     else:
+                         arr = col.to_numpy()
+
+                 # The array is copied before it is wrapped into a tensor.
+                 batch[col_name] = torch.from_numpy(arr.copy()).to(self.device)
+
+         return batch
+
+     def __len__(self):
+         return len(self.batch_indices)
+
+     def __getitem__(self, idx):
+         # Serve from memory when preloaded, otherwise read the files now.
+         if self.cached_batches is not None:
+             return self.cached_batches[idx]
+         else:
+             batch_id = self.batch_indices[idx]
+             arrow_files = self.batch_files_map[batch_id]
+             return self._load_batch(arrow_files)
diff --git a/scripts/tiger-yambda/models.py b/scripts/tiger-yambda/models.py
new file mode 100644
index 0000000..8fd0f76
--- /dev/null
+++ b/scripts/tiger-yambda/models.py
@@ -0,0 +1,223 @@
+import torch
+from transformers import T5ForConditionalGeneration, T5Config, LogitsProcessor
+
+from irec.models import TorchModel
+
+
+class CorrectItemsLogitsProcessor(LogitsProcessor):
+ def __init__(self, num_codebooks, codebook_size, mapping, num_beams, visited_items):
+ self.num_codebooks = num_codebooks
+ self.codebook_size = codebook_size
+ self.num_beams = num_beams
+
+ semantic_ids = []
+ for i in range(len(mapping)):
+ assert len(mapping[str(i)]) == num_codebooks, 'All semantic ids must have the same length'
+ semantic_ids.append(mapping[str(i)])
+
+ self.index_semantic_ids = torch.tensor(semantic_ids, dtype=torch.long, device=visited_items.device) # (num_items, semantic_ids)
+
+ batch_size, _ = visited_items.shape
+
+ self.index_semantic_ids = torch.tile(self.index_semantic_ids[None], dims=[batch_size, 1, 1]) # (batch_size, num_items, semantic_ids)
+
+ index = visited_items[..., None].tile(dims=[1, 1, num_codebooks]) # (batch_size, num_rated, semantic_ids)
+ self.index_semantic_ids = torch.scatter(
+ input=self.index_semantic_ids,
+ dim=1,
+ index=index,
+ src=torch.zeros_like(index)
+ ) # (batch_size, num_items, semantic_ids)
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+ next_sid_codebook_num = (torch.minimum((input_ids[:, -1].max() // self.codebook_size), torch.as_tensor(self.num_codebooks - 1)).item() + 1) % self.num_codebooks
+ a = torch.tile(self.index_semantic_ids[:, None, :, next_sid_codebook_num], dims=[1, self.num_beams, 1]) # (batch_size, num_beams, num_items)
+ a = a.reshape(a.shape[0] * a.shape[1], a.shape[2]) # (batch_size * num_beams, num_items)
+
+ if next_sid_codebook_num != 0:
+ b = torch.tile(self.index_semantic_ids[:, None :, :next_sid_codebook_num], dims=[1, self.num_beams, 1, 1]) # (batch_size, num_beams, num_items, sid_len)
+ b = b.reshape(b.shape[0] * b.shape[1], b.shape[2], b.shape[3]) # (batch_size * num_beams, num_items, sid_len)
+
+ current_prefixes = input_ids[:, -next_sid_codebook_num:] # (batch_size * num_beams, sid_len)
+ possible_next_items_mask = (
+ torch.eq(current_prefixes[:, None, :], b).long().sum(dim=-1) == next_sid_codebook_num
+ ) # (batch_size * num_beams, num_items)
+ a[~possible_next_items_mask] = (next_sid_codebook_num + 1) * self.codebook_size
+
+ scores_mask = torch.zeros_like(scores).bool() # (batch_size * num_beams, num_items)
+ scores_mask = torch.scatter_add(
+ input=scores_mask,
+ dim=-1,
+ index=a,
+ src=torch.ones_like(a).bool()
+ )
+
+ scores[:, :next_sid_codebook_num * self.codebook_size] = -torch.inf
+ scores[:, (next_sid_codebook_num + 1) * self.codebook_size:] = -torch.inf
+ scores[~(scores_mask.bool())] = -torch.inf
+
+ return scores
+
+
+class TigerModel(TorchModel):
+ def __init__(
+ self,
+ embedding_dim,
+ codebook_size,
+ sem_id_len,
+ num_positions,
+ user_ids_count,
+ num_heads,
+ num_encoder_layers,
+ num_decoder_layers,
+ dim_feedforward,
+ num_beams=100,
+ num_return_sequences=20,
+ d_kv=64,
+ layer_norm_eps=1e-6,
+ activation='relu',
+ dropout=0.1,
+ initializer_range=0.02,
+ logits_processor=None,
+ use_microbatching=False,
+ microbatch_size=128
+ ):
+ super().__init__()
+ self._embedding_dim = embedding_dim
+ self._codebook_size = codebook_size
+ self._num_positions = num_positions
+ self._num_heads = num_heads
+ self._num_encoder_layers = num_encoder_layers
+ self._num_decoder_layers = num_decoder_layers
+ self._dim_feedforward = dim_feedforward
+ self._num_beams = num_beams
+ self._num_return_sequences = num_return_sequences
+ self._d_kv = d_kv
+ self._layer_norm_eps = layer_norm_eps
+ self._activation = activation
+ self._dropout = dropout
+ self._sem_id_len = sem_id_len
+ self.user_ids_count = user_ids_count
+ self.logits_processor = logits_processor
+ self._use_microbatching = use_microbatching
+ self._microbatch_size = microbatch_size
+
+ unified_vocab_size = codebook_size * self._sem_id_len + self.user_ids_count + 10 # 10 for utilities
+ self.config = T5Config(
+ vocab_size=unified_vocab_size,
+ d_model=self._embedding_dim,
+ d_kv=self._d_kv,
+ d_ff=self._dim_feedforward,
+ num_layers=self._num_encoder_layers,
+ num_decoder_layers=self._num_decoder_layers,
+ num_heads=self._num_heads,
+ dropout_rate=self._dropout,
+ is_encoder_decoder=True,
+ use_cache=False,
+ pad_token_id=unified_vocab_size - 1,
+ eos_token_id=unified_vocab_size - 2,
+ decoder_start_token_id=unified_vocab_size - 3,
+ layer_norm_epsilon=self._layer_norm_eps,
+ feed_forward_proj=self._activation,
+ tie_word_embeddings=False
+ )
+ self.model = T5ForConditionalGeneration(config=self.config)
+ self._init_weights(initializer_range)
+
+ self.model = torch.compile(
+ self.model,
+ mode='reduce-overhead',
+ fullgraph=False,
+ dynamic=True
+ )
+
+ def forward(self, inputs):
+ input_semantic_ids = inputs['input.data']
+ attention_mask = inputs['input.mask']
+ target_semantic_ids = inputs['output.data']
+
+ decoder_input_ids = target_semantic_ids[:, :-1].contiguous()
+ labels = target_semantic_ids[:, 1:].contiguous()
+
+ model_output = self.model(
+ input_ids=input_semantic_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ labels=labels
+ )
+ loss = model_output['loss']
+
+ metrics = {'loss': loss.detach()}
+
+ if not self.training and not self._use_microbatching:
+ visited_batch = inputs['visited.padded']
+
+ output = self.model.generate(
+ input_ids=input_semantic_ids,
+ attention_mask=attention_mask,
+ num_beams=self._num_beams,
+ num_return_sequences=self._num_return_sequences,
+ max_length=self._sem_id_len + 1,
+ decoder_start_token_id=self.config.decoder_start_token_id,
+ eos_token_id=self.config.eos_token_id,
+ pad_token_id=self.config.pad_token_id,
+ do_sample=False,
+ early_stopping=False,
+ logits_processor=[self.logits_processor(visited_items=visited_batch)] if self.logits_processor is not None else [],
+ )
+
+ predictions = output[:, 1:].reshape(-1, self._num_return_sequences, self._sem_id_len)
+
+ all_hits = (torch.eq(predictions, labels[:, None]).sum(dim=-1)) # (batch_size, top_k)
+ elif not self.training and self._use_microbatching:
+ visited_batch = inputs['visited.padded']
+ batch_size = input_semantic_ids.shape[0]
+
+ inference_batch_size = self._microbatch_size # вместо полного batch_size
+
+ all_predictions = []
+ all_labels = []
+ # print(f"start to infer batch of shape {input_semantic_ids.shape} with new batch {inference_batch_size}")
+ for batch_idx in range(0, batch_size, inference_batch_size):
+ batch_end = min(batch_idx + inference_batch_size, batch_size)
+ batch_slice = slice(batch_idx, batch_end)
+
+ input_ids_batch = input_semantic_ids[batch_slice]
+ attention_mask_batch = attention_mask[batch_slice]
+ visited_batch_subset = visited_batch[batch_slice]
+ labels_batch = labels[batch_slice]
+
+ with torch.inference_mode():
+ output = self.model.generate(
+ input_ids=input_ids_batch,
+ attention_mask=attention_mask_batch,
+ num_beams=self._num_beams,
+ num_return_sequences=self._num_return_sequences,
+ max_length=self._sem_id_len + 1,
+ decoder_start_token_id=self.config.decoder_start_token_id,
+ eos_token_id=self.config.eos_token_id,
+ pad_token_id=self.config.pad_token_id,
+ do_sample=False,
+ early_stopping=False,
+ logits_processor=[self.logits_processor(visited_items=visited_batch_subset)] if self.logits_processor is not None else [],
+ )
+
+ predictions_batch = output[:, 1:].reshape(-1, self._num_return_sequences, self._sem_id_len)
+ all_predictions.append(predictions_batch)
+ all_labels.append(labels_batch)
+ # print("end infer of batch")
+
+ predictions = torch.cat(all_predictions, dim=0) # (batch_size, num_return_sequences, sem_id_len)
+ labels_full = torch.cat(all_labels, dim=0) # (batch_size, sem_id_len)
+ all_hits = (torch.eq(predictions, labels_full[:, None]).sum(dim=-1)) # (batch_size, top_k)
+
+ if not self.training:
+ for k in [5, 10, 20]:
+ hits = (all_hits[:, :k] == self._sem_id_len).float() # (batch_size, k)
+ recall = hits.sum(dim=-1) # (batch_size)
+ discount_factor = 1 / torch.log2(torch.arange(1, k + 1, 1).float() + 1.).to(hits.device) # (k)
+
+ metrics[f'recall@{k}'] = recall.cpu().float()
+ metrics[f'ndcg@{k}'] = torch.einsum('bk,k->b', hits, discount_factor).cpu().float()
+
+ return loss, metrics
\ No newline at end of file
diff --git a/scripts/tiger-yambda/yambda_train_4.1_plum.py b/scripts/tiger-yambda/yambda_train_4.1_plum.py
new file mode 100644
index 0000000..607c0e3
--- /dev/null
+++ b/scripts/tiger-yambda/yambda_train_4.1_plum.py
@@ -0,0 +1,230 @@
+from functools import partial
+import json
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.transforms import Collate, ToDevice
+from irec.data.dataloader import DataLoader
+from irec.runners import TrainingRunner
+from irec.utils import fix_random_seed
+
+from data import ArrowBatchDataset
+from models import TigerModel, CorrectItemsLogitsProcessor
+
+
+# PATHS
+# Everything below is joined onto IREC_PATH, which is a relative path and is
+# therefore resolved against the current working directory at launch time.
+IREC_PATH = '../../'
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir_yambda/4-1_filtered_yambda_gpu_quantile_ws_2_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Yambda/day-splits/test/yambda_quantile_tiger_T_train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Yambda/day-splits/test/yambda_quantile_tiger_T_valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Yambda/day-splits/test/yambda_quantile_tiger_T_eval_batches/')
+
+
+TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs')
+CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints')
+
+# Used both as the TensorBoard run name and as the checkpoint file name.
+EXPERIMENT_NAME = 'TEST_tiger_yambda_filtered_day-split_plum_ws_2_dp_0.2_max_300_256_1024'
+
+# OTHER SETTINGS
+SEED_VALUE = 42
+DEVICE = 'cuda'
+
+NUM_EPOCHS = 300
+# NOTE(review): MAX_SEQ_LEN, TRAIN_BATCH_SIZE and VALID_BATCH_SIZE are not
+# referenced anywhere else in this script.
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+EMBEDDING_DIM = 128
+CODEBOOK_SIZE = 256
+NUM_POSITIONS = 20
+NUM_USER_HASH = 2000
+NUM_HEADS = 6
+NUM_LAYERS = 4
+FEEDFORWARD_DIM = 1024
+KV_DIM = 64
+DROPOUT = 0.2
+NUM_BEAMS = 30
+TOP_K = 20
+NUM_CODEBOOKS = 4
+LR = 0.0001
+
+USE_MICROBATCHING = True
+MICROBATCH_SIZE = 128
+
+# Trade some float32 matmul precision for speed ('high' permits TF32-class
+# kernels; see the PyTorch documentation for the exact semantics).
+torch.set_float32_matmul_precision('high')
+# Dynamo setting; presumably needed by compiled model code not shown in this
+# script - TODO confirm.
+torch._dynamo.config.capture_scalar_outputs = True
+
+import torch._inductor.config as config
+# Inductor setting: do not use CUDA graphs for graphs with dynamic shapes.
+config.triton.cudagraph_skip_dynamic_graphs = True
+
+
+# Per-batch ranking metrics read from the model outputs during evaluation,
+# listed in the order they are logged.
+_RANKING_METRICS = ('recall@5', 'recall@10', 'recall@20', 'ndcg@5', 'ndcg@10', 'ndcg@20')
+
+
+def _evaluation_callback(dataset, name):
+    """Build the evaluation callback stack for one held-out split.
+
+    Args:
+        dataset: batches to evaluate on; passed unchanged to ``cb.Validation``.
+        name: metric namespace. Batch metrics are emitted under ``name`` and
+            accumulated as ``'{name}/loss'``, ``'{name}/recall@5'``, ...
+
+    Returns:
+        The ``cb.Validation`` callback, not yet scheduled; the caller chooses
+        the frequency with ``every_num_steps``.
+    """
+    def collect(model_outputs, _):
+        # Loss is reduced with .item(); ranking metrics are forwarded via .tolist().
+        collected = {'loss': model_outputs['loss'].item()}
+        for metric in _RANKING_METRICS:
+            collected[metric] = model_outputs[metric].tolist()
+        return collected
+
+    return cb.Validation(
+        dataset=dataset,
+        callbacks=[
+            cb.BatchMetrics(metrics=collect, name=name),
+            cb.MetricAccumulator(
+                accumulators={
+                    f'{name}/{metric}': cb.MeanAccumulator()
+                    for metric in ('loss',) + _RANKING_METRICS
+                },
+            ),
+        ],
+    )
+
+
+def main():
+    """Train the TIGER model on the pre-built Yambda Arrow batches.
+
+    Seeds the RNGs, loads the item-to-semantic-id mapping, opens the train /
+    validation / eval batch directories, builds the model and the AdamW
+    optimizer, wires up logging, periodic evaluation and early stopping, and
+    runs ``TrainingRunner``.
+    """
+    fix_random_seed(SEED_VALUE)
+
+    with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+        mappings = json.load(f)
+
+    # batch_size=1: each ArrowBatchDataset item is presumably an already
+    # assembled batch - TODO confirm against ArrowBatchDataset.
+    train_dataloader = DataLoader(
+        ArrowBatchDataset(
+            TRAIN_BATCHES_DIR,
+            device='cpu',
+            preload=True
+        ),
+        batch_size=1,
+        shuffle=True,
+        num_workers=0,
+        pin_memory=True,
+        collate_fn=Collate()
+    ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)
+
+    # Held-out splits are loaded straight onto the target device.
+    valid_dataset = ArrowBatchDataset(
+        VALID_BATCHES_DIR,
+        device=DEVICE,
+        preload=True
+    )
+
+    eval_dataset = ArrowBatchDataset(
+        EVAL_BATCHES_DIR,
+        device=DEVICE,
+        preload=True
+    )
+
+    model = TigerModel(
+        embedding_dim=EMBEDDING_DIM,
+        codebook_size=CODEBOOK_SIZE,
+        sem_id_len=NUM_CODEBOOKS,
+        user_ids_count=NUM_USER_HASH,
+        num_positions=NUM_POSITIONS,
+        num_heads=NUM_HEADS,
+        num_encoder_layers=NUM_LAYERS,
+        num_decoder_layers=NUM_LAYERS,
+        dim_feedforward=FEEDFORWARD_DIM,
+        num_beams=NUM_BEAMS,
+        num_return_sequences=TOP_K,
+        activation='relu',
+        d_kv=KV_DIM,
+        dropout=DROPOUT,
+        layer_norm_eps=1e-6,
+        initializer_range=0.02,
+        # Factory: the model supplies ``visited_items`` per call at generation time.
+        logits_processor=partial(
+            CorrectItemsLogitsProcessor,
+            NUM_CODEBOOKS,
+            CODEBOOK_SIZE,
+            mappings,
+            NUM_BEAMS
+        ),
+        use_microbatching=USE_MICROBATCHING,
+        microbatch_size=MICROBATCH_SIZE
+    ).to(DEVICE)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.debug(f'Overall parameters: {total_params:,}')
+    logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=LR,
+    )
+
+    # Fixed logging / evaluation period in optimizer steps (not a true epoch).
+    EPOCH_NUM_STEPS = 1024  # int(len(train_dataloader) // NUM_EPOCHS)
+
+    callbacks = [
+        cb.BatchMetrics(metrics=lambda model_outputs, _: {
+            'loss': model_outputs['loss'].item(),
+        }, name='train'),
+        cb.MetricAccumulator(
+            accumulators={
+                'train/loss': cb.MeanAccumulator(),
+            },
+            reset_every_num_steps=EPOCH_NUM_STEPS
+        ),
+
+        _evaluation_callback(valid_dataset, 'validation').every_num_steps(EPOCH_NUM_STEPS),
+        # The eval split is scored four times less often than validation.
+        _evaluation_callback(eval_dataset, 'eval').every_num_steps(EPOCH_NUM_STEPS * 4),
+
+        cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
+        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),
+
+        cb.EarlyStopping(
+            metric='validation/ndcg@20',
+            patience=40 * 4,
+            minimize=False,
+            model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
+        ).every_num_steps(EPOCH_NUM_STEPS)
+
+        # Disabled debugging helpers:
+        # cb.Profiler(
+        #     wait=10,
+        #     warmup=10,
+        #     active=10,
+        #     logdir=TENSORBOARD_LOGDIR
+        # ),
+        # cb.StopAfterNumSteps(40)
+
+    ]
+
+    logger.debug('Everything is ready for training process!')
+
+    runner = TrainingRunner(
+        model=model,
+        optimizer=optimizer,
+        dataset=train_dataloader,
+        callbacks=callbacks,
+    )
+    runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger-yambda/yambda_varka_4.1_plum.py b/scripts/tiger-yambda/yambda_varka_4.1_plum.py
new file mode 100644
index 0000000..9c00704
--- /dev/null
+++ b/scripts/tiger-yambda/yambda_varka_4.1_plum.py
@@ -0,0 +1,304 @@
+from collections import defaultdict
+import json
+import murmurhash
+import numpy as np
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.transforms import Collate, Transform
+from irec.data.dataloader import DataLoader
+
+from data import Dataset
+
+print("tiger no arrow varka 4.1")
+
+# PATHS
+# Everything below is joined onto IREC_PATH, which is a relative path and is
+# therefore resolved against the current working directory at launch time.
+
+IREC_PATH = '../../'
+INTERACTIONS_TRAIN_PATH = os.path.join(IREC_PATH, 'data/Yambda/day-splits/merged_for_exps_filtered/exp_4_0.9_inter_tiger_train.json')
+INTERACTIONS_VALID_PATH = os.path.join(IREC_PATH, 'data/Yambda/day-splits/merged_for_exps_filtered/valid_set.json')
+INTERACTIONS_TEST_PATH = os.path.join(IREC_PATH, 'data/Yambda/day-splits/merged_for_exps_filtered/test_set.json')
+
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir_yambda/4-1_filtered_yambda_gpu_quantile_ws_2_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Yambda/day-splits/test/yambda_quantile_tiger_T_train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Yambda/day-splits/test/yambda_quantile_tiger_T_valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Yambda/day-splits/test/yambda_quantile_tiger_T_eval_batches/')
+
+
+# OTHER SETTINGS
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
+MAX_TRAIN_EVENTS = 300
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+NUM_USER_HASH = 2000
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 4
+
+# One shared id space: codebook tokens, then hashed-user tokens, then service tokens.
+UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10  # 10 for utilities
+# Special tokens take the last ids of the vocabulary. They must be plain ints:
+# a trailing comma here would silently turn each constant into a 1-tuple.
+PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1
+EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2
+DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3
+
+
+class TigerProcessing(Transform):
+    """Assemble the encoder input and decoder target arrays for one batch.
+
+    Reads ``item.semantic.padded`` / ``item.semantic.mask``,
+    ``user.hashed.ids`` and ``labels.semantic.padded`` and adds
+    ``input.data``, ``input.mask`` and ``output.data`` to the batch.
+    """
+
+    def __call__(self, batch):
+        history = batch['item.semantic.padded']
+        history_mask = batch['item.semantic.mask']
+        num_rows = history_mask.shape[0]
+
+        # Padded slots are overwritten in place, so the array stored under
+        # 'item.semantic.padded' is modified as well.
+        history[~history_mask] = PAD_TOKEN_ID  # TODO ???
+
+        # User tokens are offset past the NUM_CODEBOOKS * CODEBOOK_SIZE item-code ids
+        # and appended as one extra, always-attended column.
+        user_column = NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
+        always_on = np.ones((num_rows, 1), dtype=history_mask.dtype)
+
+        batch['input.data'] = np.concatenate([history, user_column], axis=-1)
+        batch['input.mask'] = np.concatenate([history_mask, always_on], axis=-1)
+
+        # Decoder target: a start-token column followed by the label codes.
+        start_column = np.ones((num_rows, 1), dtype=np.int64) * DECODER_START_TOKEN_ID
+        batch['output.data'] = np.concatenate(
+            [start_column, batch['labels.semantic.padded']],
+            axis=-1
+        )
+
+        return batch
+
+
+class ToMasked(Transform):
+    """Scatter a flat, concatenated sequence array into a zero-padded array plus mask.
+
+    Reads ``'{prefix}.ids'`` and ``'{prefix}.length'`` and writes
+    ``'{prefix}.padded'`` and ``'{prefix}.mask'``.
+    """
+
+    def __init__(self, prefix, is_right_aligned=False):
+        """
+        Args:
+            prefix: batch key prefix of the sequence to pad.
+            is_right_aligned: if True the valid positions sit at the end of
+                each row (padding on the left) instead of at the start.
+        """
+        self._prefix = prefix
+        self._is_right_aligned = is_right_aligned
+
+    def __call__(self, batch):
+        prefix = self._prefix
+        flat = batch[f'{prefix}.ids']
+        lengths = batch[f'{prefix}.length']
+
+        num_rows = lengths.shape[0]
+        width = int(lengths.max())
+
+        # 1-D input holds indices, 2-D input holds one embedding row per element.
+        assert len(flat.shape) in (1, 2)
+        padded = np.zeros(
+            (num_rows, width) + tuple(flat.shape[1:]),
+            dtype=flat.dtype
+        )  # (batch_size, max_seq_len[, emb_dim])
+
+        # True for the first lengths[i] positions of row i.
+        mask = np.arange(width)[None] < lengths[:, None]
+        if self._is_right_aligned:
+            mask = np.flip(mask, axis=-1)
+
+        # Boolean assignment fills the True slots in row-major order.
+        padded[mask] = flat
+
+        batch[f'{prefix}.padded'] = padded
+        batch[f'{prefix}.mask'] = mask
+
+        return batch
+
+
+class SemanticIdsMapper(Transform):
+    """Replace item ids with their flattened semantic-id codes.
+
+    For every prefix in ``names`` that has a ``'{name}.ids'`` entry in the
+    batch, adds ``'{name}.semantic.ids'`` (the codes of all items,
+    concatenated) and ``'{name}.semantic.length'`` (the original lengths
+    multiplied by the number of codes per item).
+    """
+
+    def __init__(self, mapping, names=()):
+        """
+        Args:
+            mapping: dict from item id (as a decimal string) to its list of codes.
+            names: batch key prefixes to translate, e.g. ``['item', 'labels']``.
+        """
+        super().__init__()
+        self._mapping = mapping
+        # Own copy: no shared mutable default, no aliasing of the caller's list.
+        self._names = tuple(names)
+
+        item_ids = [int(key) for key in mapping.keys()]
+        max_item_id = max(item_ids)
+        print(len(item_ids), min(item_ids), max_item_id)
+
+        # Dense lookup table indexed by item id; ids absent from the mapping
+        # get a row of -1 so that __call__ can detect them.
+        missing_row = [-1] * NUM_CODEBOOKS
+        data = [mapping.get(str(item_id), missing_row) for item_id in range(max_item_id + 1)]
+
+        self._mapping_tensor = torch.tensor(data, dtype=torch.long)
+        self._semantic_length = self._mapping_tensor.shape[-1]
+
+        missing_count = (max_item_id + 1) - len(mapping)
+        print(f"Mapping: {len(mapping)} items, {missing_count} missing (-1 filled)")
+
+    def __call__(self, batch):
+        """Add the semantic ids/lengths for each configured prefix present in ``batch``.
+
+        Raises:
+            AssertionError: if an id is negative, exceeds the largest mapped
+                id, or has no entry in the mapping.
+        """
+        for name in self._names:
+            if f'{name}.ids' in batch:
+                ids = batch[f'{name}.ids']
+                lengths = batch[f'{name}.length']
+                assert ids.min() >= 0
+                assert ids.max() < self._mapping_tensor.shape[0]
+                semantic_ids = self._mapping_tensor[ids].flatten()
+
+                assert (semantic_ids != -1).all(), \
+                    f"Missing mappings detected in {name}! Invalid positions: {(semantic_ids == -1).sum()} out of {len(semantic_ids)}"
+
+                batch[f'{name}.semantic.ids'] = semantic_ids.numpy()
+                batch[f'{name}.semantic.length'] = lengths * self._semantic_length
+
+        return batch
+
+
+class UserHashing(Transform):
+    """Map every user id of the batch to one of ``hash_size`` buckets."""
+
+    def __init__(self, hash_size):
+        """
+        Args:
+            hash_size: number of buckets; the modulus applied to the hash.
+        """
+        super().__init__()
+        self._hash_size = hash_size
+
+    def __call__(self, batch):
+        # The id is stringified first, then hashed and reduced modulo hash_size.
+        buckets = [
+            murmurhash.hash(str(user_id)) % self._hash_size
+            for user_id in batch['user.ids']
+        ]
+        batch['user.hashed.ids'] = np.array(buckets, dtype=np.int64)
+        return batch
+
+
+def save_batches_to_arrow(batches, output_dir):
+    """Write every batch to LZ4-compressed Feather (Arrow) files.
+
+    The arrays of a batch are grouped by their first-dimension length, and each
+    group becomes one file ``batch_{index:06d}_len_{length}.arrow``. The original
+    shape and dtype of every array are stored in the table's schema metadata
+    under ``'{key}_shape'`` and ``'{key}_dtype'``.
+
+    Args:
+        batches: iterable of dicts mapping a key to a NumPy array.
+        output_dir: target directory; it is created here and must not exist yet.
+
+    Raises:
+        FileExistsError: if ``output_dir`` already exists.
+    """
+    output_dir = Path(output_dir)
+    # exist_ok=False: refuse to write into an existing directory.
+    output_dir.mkdir(parents=True, exist_ok=False)
+
+    for batch_idx, batch in enumerate(batches):
+        columns_by_length = defaultdict(dict)
+        metadata_by_length = defaultdict(dict)
+
+        for key, value in batch.items():
+            length = len(value)
+
+            metadata_by_length[length][f'{key}_shape'] = str(value.shape)
+            metadata_by_length[length][f'{key}_dtype'] = str(value.dtype)
+
+            if value.ndim == 1:
+                # 1-D array: stored as is.
+                column = value
+            elif value.ndim == 2:
+                # 2-D array: stored as a list of lists, one inner list per row.
+                column = value.tolist()
+            else:
+                # More than 2-D: flattened; the shape is kept in the metadata.
+                # NOTE(review): the flattened column is longer than `length`, so it
+                # cannot share a table with other columns of this group - confirm
+                # that no such arrays occur in practice.
+                column = value.flatten()
+
+            columns_by_length[length][key] = pa.array(column)
+
+        for length, columns in columns_by_length.items():
+            table = pa.table(columns).replace_schema_metadata(metadata_by_length[length])
+
+            feather.write_feather(
+                table,
+                output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+                compression='lz4'
+            )
+
+
+def _build_batch_pipeline(dataset, mappings, batch_size, shuffle, drop_last, with_visited):
+    """Chain the preprocessing transforms that turn ``dataset`` into model-ready batches.
+
+    Args:
+        dataset: split to iterate over.
+        mappings: item-id to semantic-id mapping passed to ``SemanticIdsMapper``.
+        batch_size, shuffle, drop_last: forwarded to ``DataLoader``.
+        with_visited: also pad the ``visited`` sequence (done for the
+            validation and eval splits only).
+
+    Returns:
+        The mapped data loader; iterating it yields the processed batch dicts.
+    """
+    pipeline = (
+        DataLoader(
+            dataset=dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            drop_last=drop_last
+        )
+        .map(Collate())
+        .map(UserHashing(NUM_USER_HASH))
+        .map(SemanticIdsMapper(mappings, names=['item', 'labels']))
+        .map(ToMasked('item.semantic', is_right_aligned=True))
+        .map(ToMasked('labels.semantic', is_right_aligned=True))
+    )
+    if with_visited:
+        pipeline = pipeline.map(ToMasked('visited', is_right_aligned=True))
+    return pipeline.map(TigerProcessing())
+
+
+def main():
+    """Pre-compute the train / validation / eval batches and store them as Arrow files."""
+    # SEED_VALUE was previously defined but never applied although the train
+    # loader shuffles. Seeding here assumes the loader draws from the global
+    # RNGs that fix_random_seed sets - TODO confirm.
+    from irec.utils import fix_random_seed
+    fix_random_seed(SEED_VALUE)
+
+    with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+        mappings = json.load(f)
+
+    data = Dataset.create_timestamp_based(
+        train_json_path=INTERACTIONS_TRAIN_PATH,
+        validation_json_path=INTERACTIONS_VALID_PATH,
+        test_json_path=INTERACTIONS_TEST_PATH,
+        max_sequence_length=MAX_SEQ_LEN,
+        sampler_type='tiger',
+        min_sample_len=2,
+        is_extended=True,
+        max_train_events=MAX_TRAIN_EVENTS
+    )
+
+    train_dataset, valid_dataset, eval_dataset = data.get_datasets()
+
+    train_dataloader = _build_batch_pipeline(
+        train_dataset, mappings, TRAIN_BATCH_SIZE,
+        shuffle=True, drop_last=True, with_visited=False
+    )
+    valid_dataloader = _build_batch_pipeline(
+        valid_dataset, mappings, VALID_BATCH_SIZE,
+        shuffle=False, drop_last=False, with_visited=True
+    )
+    eval_dataloader = _build_batch_pipeline(
+        eval_dataset, mappings, VALID_BATCH_SIZE,
+        shuffle=False, drop_last=False, with_visited=True
+    )
+
+    # The writer only iterates its input once, so the loaders are streamed
+    # instead of being collected in lists first. This keeps one batch in
+    # memory at a time and fails before any work if the directory exists.
+    save_batches_to_arrow(train_dataloader, TRAIN_BATCHES_DIR)
+    save_batches_to_arrow(valid_dataloader, VALID_BATCHES_DIR)
+    save_batches_to_arrow(eval_dataloader, EVAL_BATCHES_DIR)
+
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger/beauty_exps/train_4.1_plum.py b/scripts/tiger/beauty_exps/train_4.1_plum.py
new file mode 100644
index 0000000..8daf273
--- /dev/null
+++ b/scripts/tiger/beauty_exps/train_4.1_plum.py
@@ -0,0 +1,225 @@
+from functools import partial
+import json
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.transforms import Collate, ToDevice
+from irec.data.dataloader import DataLoader
+from irec.runners import TrainingRunner
+from irec.utils import fix_random_seed
+
+from data import ArrowBatchDataset
+from models import TigerModel, CorrectItemsLogitsProcessor
+
+
+# PATHS
+# Everything below is joined onto IREC_PATH, which is a relative path and is
+# therefore resolved against the current working directory at launch time.
+# NOTE(review): this file lives in scripts/tiger/beauty_exps/, so '../../'
+# reaches the repository root only when launched from scripts/tiger/ -
+# confirm the intended working directory.
+IREC_PATH = '../../'
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-1_updated_quantile_plum_rqvae_beauty_ws_2_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-1_train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-1_valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-1_eval_batches/')
+
+TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs')
+CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints')
+
+# Used both as the TensorBoard run name and as the checkpoint file name.
+EXPERIMENT_NAME = 'tiger_beauty_updated_quantile_4-1_plum_ws_2_dp_0.2'
+
+# OTHER SETTINGS
+SEED_VALUE = 42
+DEVICE = 'cuda'
+
+NUM_EPOCHS = 300
+# NOTE(review): MAX_SEQ_LEN, TRAIN_BATCH_SIZE and VALID_BATCH_SIZE are not
+# referenced anywhere else in this script.
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+EMBEDDING_DIM = 128
+CODEBOOK_SIZE = 256
+NUM_POSITIONS = 20
+NUM_USER_HASH = 2000
+NUM_HEADS = 6
+NUM_LAYERS = 4
+FEEDFORWARD_DIM = 1024
+KV_DIM = 64
+DROPOUT = 0.2
+NUM_BEAMS = 30
+TOP_K = 20
+NUM_CODEBOOKS = 4
+LR = 0.0001
+
+
+# Trade some float32 matmul precision for speed ('high' permits TF32-class
+# kernels; see the PyTorch documentation for the exact semantics).
+torch.set_float32_matmul_precision('high')
+# Dynamo setting; presumably needed by compiled model code not shown in this
+# script - TODO confirm.
+torch._dynamo.config.capture_scalar_outputs = True
+
+import torch._inductor.config as config
+# Inductor setting: do not use CUDA graphs for graphs with dynamic shapes.
+config.triton.cudagraph_skip_dynamic_graphs = True
+
+
+# Per-batch ranking metrics read from the model outputs during evaluation,
+# listed in the order they are logged.
+_RANKING_METRICS = ('recall@5', 'recall@10', 'recall@20', 'ndcg@5', 'ndcg@10', 'ndcg@20')
+
+
+def _evaluation_callback(dataset, name):
+    """Build the evaluation callback stack for one held-out split.
+
+    Args:
+        dataset: batches to evaluate on; passed unchanged to ``cb.Validation``.
+        name: metric namespace. Batch metrics are emitted under ``name`` and
+            accumulated as ``'{name}/loss'``, ``'{name}/recall@5'``, ...
+
+    Returns:
+        The ``cb.Validation`` callback, not yet scheduled; the caller chooses
+        the frequency with ``every_num_steps``.
+    """
+    def collect(model_outputs, _):
+        # Loss is reduced with .item(); ranking metrics are forwarded via .tolist().
+        collected = {'loss': model_outputs['loss'].item()}
+        for metric in _RANKING_METRICS:
+            collected[metric] = model_outputs[metric].tolist()
+        return collected
+
+    return cb.Validation(
+        dataset=dataset,
+        callbacks=[
+            cb.BatchMetrics(metrics=collect, name=name),
+            cb.MetricAccumulator(
+                accumulators={
+                    f'{name}/{metric}': cb.MeanAccumulator()
+                    for metric in ('loss',) + _RANKING_METRICS
+                },
+            ),
+        ],
+    )
+
+
+def main():
+    """Train the TIGER model on the pre-built Beauty Arrow batches.
+
+    Seeds the RNGs, loads the item-to-semantic-id mapping, opens the train /
+    validation / eval batch directories, builds the model and the AdamW
+    optimizer, wires up logging, periodic evaluation and early stopping, and
+    runs ``TrainingRunner``.
+    """
+    fix_random_seed(SEED_VALUE)
+
+    with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+        mappings = json.load(f)
+
+    # batch_size=1: each ArrowBatchDataset item is presumably an already
+    # assembled batch - TODO confirm against ArrowBatchDataset.
+    train_dataloader = DataLoader(
+        ArrowBatchDataset(
+            TRAIN_BATCHES_DIR,
+            device='cpu',
+            preload=True
+        ),
+        batch_size=1,
+        shuffle=True,
+        num_workers=0,
+        pin_memory=True,
+        collate_fn=Collate()
+    ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)
+
+    # Held-out splits are loaded straight onto the target device.
+    valid_dataset = ArrowBatchDataset(
+        VALID_BATCHES_DIR,
+        device=DEVICE,
+        preload=True
+    )
+
+    eval_dataset = ArrowBatchDataset(
+        EVAL_BATCHES_DIR,
+        device=DEVICE,
+        preload=True
+    )
+
+    model = TigerModel(
+        embedding_dim=EMBEDDING_DIM,
+        codebook_size=CODEBOOK_SIZE,
+        sem_id_len=NUM_CODEBOOKS,
+        user_ids_count=NUM_USER_HASH,
+        num_positions=NUM_POSITIONS,
+        num_heads=NUM_HEADS,
+        num_encoder_layers=NUM_LAYERS,
+        num_decoder_layers=NUM_LAYERS,
+        dim_feedforward=FEEDFORWARD_DIM,
+        num_beams=NUM_BEAMS,
+        num_return_sequences=TOP_K,
+        activation='relu',
+        d_kv=KV_DIM,
+        dropout=DROPOUT,
+        layer_norm_eps=1e-6,
+        initializer_range=0.02,
+        # Factory: the remaining constructor arguments are supplied by the model.
+        logits_processor=partial(
+            CorrectItemsLogitsProcessor,
+            NUM_CODEBOOKS,
+            CODEBOOK_SIZE,
+            mappings,
+            NUM_BEAMS
+        )
+    ).to(DEVICE)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.debug(f'Overall parameters: {total_params:,}')
+    logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=LR,
+    )
+
+    # Fixed logging / evaluation period in optimizer steps (not a true epoch).
+    EPOCH_NUM_STEPS = 1024  # int(len(train_dataloader) // NUM_EPOCHS)
+
+    callbacks = [
+        cb.BatchMetrics(metrics=lambda model_outputs, _: {
+            'loss': model_outputs['loss'].item(),
+        }, name='train'),
+        cb.MetricAccumulator(
+            accumulators={
+                'train/loss': cb.MeanAccumulator(),
+            },
+            reset_every_num_steps=EPOCH_NUM_STEPS
+        ),
+
+        _evaluation_callback(valid_dataset, 'validation').every_num_steps(EPOCH_NUM_STEPS),
+        _evaluation_callback(eval_dataset, 'eval').every_num_steps(EPOCH_NUM_STEPS),
+
+        cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
+        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),
+
+        # NOTE(review): early stopping monitors the 'eval' split, not the
+        # 'validation' split. If EVAL_BATCHES_DIR holds the test set, the
+        # checkpoint is selected on test data - confirm this is intended.
+        cb.EarlyStopping(
+            metric='eval/ndcg@20',
+            patience=40,
+            minimize=False,
+            model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
+        ).every_num_steps(EPOCH_NUM_STEPS)
+
+        # Disabled debugging helpers:
+        # cb.Profiler(
+        #     wait=10,
+        #     warmup=10,
+        #     active=10,
+        #     logdir=TENSORBOARD_LOGDIR
+        # ),
+        # cb.StopAfterNumSteps(40)
+
+    ]
+
+    logger.debug('Everything is ready for training process!')
+
+    runner = TrainingRunner(
+        model=model,
+        optimizer=optimizer,
+        dataset=train_dataloader,
+        callbacks=callbacks,
+    )
+    runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger/beauty_exps/train_4.2_plum.py b/scripts/tiger/beauty_exps/train_4.2_plum.py
new file mode 100644
index 0000000..580bcb5
--- /dev/null
+++ b/scripts/tiger/beauty_exps/train_4.2_plum.py
@@ -0,0 +1,225 @@
+from functools import partial
+import json
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.transforms import Collate, ToDevice
+from irec.data.dataloader import DataLoader
+from irec.runners import TrainingRunner
+from irec.utils import fix_random_seed
+
+from data import ArrowBatchDataset
+from models import TigerModel, CorrectItemsLogitsProcessor
+
+
+# PATHS
+# Everything below is joined onto IREC_PATH, which is a relative path and is
+# therefore resolved against the current working directory at launch time.
+# NOTE(review): this file lives in scripts/tiger/beauty_exps/, so '../../'
+# reaches the repository root only when launched from scripts/tiger/ -
+# confirm the intended working directory.
+IREC_PATH = '../../'
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-2_updated_quantile_plum_rqvae_beauty_ws_2_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-2_train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-2_valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-2_eval_batches/')
+
+TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs')
+CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints')
+
+# Used both as the TensorBoard run name and as the checkpoint file name.
+EXPERIMENT_NAME = 'tiger_beauty_updated_quantile_4-2_plum_ws_2_dp_0.2'
+
+# OTHER SETTINGS
+SEED_VALUE = 42
+DEVICE = 'cuda'
+
+NUM_EPOCHS = 300
+# NOTE(review): MAX_SEQ_LEN, TRAIN_BATCH_SIZE and VALID_BATCH_SIZE are not
+# referenced anywhere else in this script.
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+EMBEDDING_DIM = 128
+CODEBOOK_SIZE = 256
+NUM_POSITIONS = 20
+NUM_USER_HASH = 2000
+NUM_HEADS = 6
+NUM_LAYERS = 4
+FEEDFORWARD_DIM = 1024
+KV_DIM = 64
+DROPOUT = 0.2
+NUM_BEAMS = 30
+TOP_K = 20
+NUM_CODEBOOKS = 4
+LR = 0.0001
+
+
+# Trade some float32 matmul precision for speed ('high' permits TF32-class
+# kernels; see the PyTorch documentation for the exact semantics).
+torch.set_float32_matmul_precision('high')
+# Dynamo setting; presumably needed by compiled model code not shown in this
+# script - TODO confirm.
+torch._dynamo.config.capture_scalar_outputs = True
+
+import torch._inductor.config as config
+# Inductor setting: do not use CUDA graphs for graphs with dynamic shapes.
+config.triton.cudagraph_skip_dynamic_graphs = True
+
+
+# Per-batch ranking metrics read from the model outputs during evaluation,
+# listed in the order they are logged.
+_RANKING_METRICS = ('recall@5', 'recall@10', 'recall@20', 'ndcg@5', 'ndcg@10', 'ndcg@20')
+
+
+def _evaluation_callback(dataset, name):
+    """Build the evaluation callback stack for one held-out split.
+
+    Args:
+        dataset: batches to evaluate on; passed unchanged to ``cb.Validation``.
+        name: metric namespace. Batch metrics are emitted under ``name`` and
+            accumulated as ``'{name}/loss'``, ``'{name}/recall@5'``, ...
+
+    Returns:
+        The ``cb.Validation`` callback, not yet scheduled; the caller chooses
+        the frequency with ``every_num_steps``.
+    """
+    def collect(model_outputs, _):
+        # Loss is reduced with .item(); ranking metrics are forwarded via .tolist().
+        collected = {'loss': model_outputs['loss'].item()}
+        for metric in _RANKING_METRICS:
+            collected[metric] = model_outputs[metric].tolist()
+        return collected
+
+    return cb.Validation(
+        dataset=dataset,
+        callbacks=[
+            cb.BatchMetrics(metrics=collect, name=name),
+            cb.MetricAccumulator(
+                accumulators={
+                    f'{name}/{metric}': cb.MeanAccumulator()
+                    for metric in ('loss',) + _RANKING_METRICS
+                },
+            ),
+        ],
+    )
+
+
+def main():
+    """Train the TIGER model on the pre-built Beauty Arrow batches.
+
+    Seeds the RNGs, loads the item-to-semantic-id mapping, opens the train /
+    validation / eval batch directories, builds the model and the AdamW
+    optimizer, wires up logging, periodic evaluation and early stopping, and
+    runs ``TrainingRunner``.
+    """
+    fix_random_seed(SEED_VALUE)
+
+    with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+        mappings = json.load(f)
+
+    # batch_size=1: each ArrowBatchDataset item is presumably an already
+    # assembled batch - TODO confirm against ArrowBatchDataset.
+    train_dataloader = DataLoader(
+        ArrowBatchDataset(
+            TRAIN_BATCHES_DIR,
+            device='cpu',
+            preload=True
+        ),
+        batch_size=1,
+        shuffle=True,
+        num_workers=0,
+        pin_memory=True,
+        collate_fn=Collate()
+    ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)
+
+    # Held-out splits are loaded straight onto the target device.
+    valid_dataset = ArrowBatchDataset(
+        VALID_BATCHES_DIR,
+        device=DEVICE,
+        preload=True
+    )
+
+    eval_dataset = ArrowBatchDataset(
+        EVAL_BATCHES_DIR,
+        device=DEVICE,
+        preload=True
+    )
+
+    model = TigerModel(
+        embedding_dim=EMBEDDING_DIM,
+        codebook_size=CODEBOOK_SIZE,
+        sem_id_len=NUM_CODEBOOKS,
+        user_ids_count=NUM_USER_HASH,
+        num_positions=NUM_POSITIONS,
+        num_heads=NUM_HEADS,
+        num_encoder_layers=NUM_LAYERS,
+        num_decoder_layers=NUM_LAYERS,
+        dim_feedforward=FEEDFORWARD_DIM,
+        num_beams=NUM_BEAMS,
+        num_return_sequences=TOP_K,
+        activation='relu',
+        d_kv=KV_DIM,
+        dropout=DROPOUT,
+        layer_norm_eps=1e-6,
+        initializer_range=0.02,
+        # Factory: the remaining constructor arguments are supplied by the model.
+        logits_processor=partial(
+            CorrectItemsLogitsProcessor,
+            NUM_CODEBOOKS,
+            CODEBOOK_SIZE,
+            mappings,
+            NUM_BEAMS
+        )
+    ).to(DEVICE)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.debug(f'Overall parameters: {total_params:,}')
+    logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=LR,
+    )
+
+    # Fixed logging / evaluation period in optimizer steps (not a true epoch).
+    EPOCH_NUM_STEPS = 1024  # int(len(train_dataloader) // NUM_EPOCHS)
+
+    callbacks = [
+        cb.BatchMetrics(metrics=lambda model_outputs, _: {
+            'loss': model_outputs['loss'].item(),
+        }, name='train'),
+        cb.MetricAccumulator(
+            accumulators={
+                'train/loss': cb.MeanAccumulator(),
+            },
+            reset_every_num_steps=EPOCH_NUM_STEPS
+        ),
+
+        _evaluation_callback(valid_dataset, 'validation').every_num_steps(EPOCH_NUM_STEPS),
+        _evaluation_callback(eval_dataset, 'eval').every_num_steps(EPOCH_NUM_STEPS),
+
+        cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
+        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),
+
+        # NOTE(review): early stopping monitors the 'eval' split, not the
+        # 'validation' split. If EVAL_BATCHES_DIR holds the test set, the
+        # checkpoint is selected on test data - confirm this is intended.
+        cb.EarlyStopping(
+            metric='eval/ndcg@20',
+            patience=40,
+            minimize=False,
+            model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
+        ).every_num_steps(EPOCH_NUM_STEPS)
+
+        # Disabled debugging helpers:
+        # cb.Profiler(
+        #     wait=10,
+        #     warmup=10,
+        #     active=10,
+        #     logdir=TENSORBOARD_LOGDIR
+        # ),
+        # cb.StopAfterNumSteps(40)
+
+    ]
+
+    logger.debug('Everything is ready for training process!')
+
+    runner = TrainingRunner(
+        model=model,
+        optimizer=optimizer,
+        dataset=train_dataloader,
+        callbacks=callbacks,
+    )
+    runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger/beauty_exps/train_4.3_plum.py b/scripts/tiger/beauty_exps/train_4.3_plum.py
new file mode 100644
index 0000000..f98e9fd
--- /dev/null
+++ b/scripts/tiger/beauty_exps/train_4.3_plum.py
@@ -0,0 +1,225 @@
+from functools import partial
+import json
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.transforms import Collate, ToDevice
+from irec.data.dataloader import DataLoader
+from irec.runners import TrainingRunner
+from irec.utils import fix_random_seed
+
+from data import ArrowBatchDataset
+from models import TigerModel, CorrectItemsLogitsProcessor
+
+
+# PATHS
+# Everything below is joined onto IREC_PATH, which is a relative path and is
+# therefore resolved against the current working directory at launch time.
+# NOTE(review): this file lives in scripts/tiger/beauty_exps/, so '../../'
+# reaches the repository root only when launched from scripts/tiger/ -
+# confirm the intended working directory.
+IREC_PATH = '../../'
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-3_updated_quantile_plum_rqvae_beauty_ws_2_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-3_train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-3_valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-3_eval_batches/')
+
+TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs')
+CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints')
+
+# Used both as the TensorBoard run name and as the checkpoint file name.
+EXPERIMENT_NAME = 'tiger_beauty_updated_quantile_4-3_plum_ws_2_dp_0.2'
+
+# OTHER SETTINGS
+SEED_VALUE = 42
+DEVICE = 'cuda'
+
+NUM_EPOCHS = 300
+# NOTE(review): MAX_SEQ_LEN, TRAIN_BATCH_SIZE and VALID_BATCH_SIZE are not
+# referenced anywhere else in this script.
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+EMBEDDING_DIM = 128
+CODEBOOK_SIZE = 256
+NUM_POSITIONS = 20
+NUM_USER_HASH = 2000
+NUM_HEADS = 6
+NUM_LAYERS = 4
+FEEDFORWARD_DIM = 1024
+KV_DIM = 64
+DROPOUT = 0.2
+NUM_BEAMS = 30
+TOP_K = 20
+NUM_CODEBOOKS = 4
+LR = 0.0001
+
+
+# Trade some float32 matmul precision for speed ('high' permits TF32-class
+# kernels; see the PyTorch documentation for the exact semantics).
+torch.set_float32_matmul_precision('high')
+# Dynamo setting; presumably needed by compiled model code not shown in this
+# script - TODO confirm.
+torch._dynamo.config.capture_scalar_outputs = True
+
+import torch._inductor.config as config
+# Inductor setting: do not use CUDA graphs for graphs with dynamic shapes.
+config.triton.cudagraph_skip_dynamic_graphs = True
+
+
+# Per-batch ranking metrics read from the model outputs during evaluation,
+# listed in the order they are logged.
+_RANKING_METRICS = ('recall@5', 'recall@10', 'recall@20', 'ndcg@5', 'ndcg@10', 'ndcg@20')
+
+
+def _evaluation_callback(dataset, name):
+    """Build the evaluation callback stack for one held-out split.
+
+    Args:
+        dataset: batches to evaluate on; passed unchanged to ``cb.Validation``.
+        name: metric namespace. Batch metrics are emitted under ``name`` and
+            accumulated as ``'{name}/loss'``, ``'{name}/recall@5'``, ...
+
+    Returns:
+        The ``cb.Validation`` callback, not yet scheduled; the caller chooses
+        the frequency with ``every_num_steps``.
+    """
+    def collect(model_outputs, _):
+        # Loss is reduced with .item(); ranking metrics are forwarded via .tolist().
+        collected = {'loss': model_outputs['loss'].item()}
+        for metric in _RANKING_METRICS:
+            collected[metric] = model_outputs[metric].tolist()
+        return collected
+
+    return cb.Validation(
+        dataset=dataset,
+        callbacks=[
+            cb.BatchMetrics(metrics=collect, name=name),
+            cb.MetricAccumulator(
+                accumulators={
+                    f'{name}/{metric}': cb.MeanAccumulator()
+                    for metric in ('loss',) + _RANKING_METRICS
+                },
+            ),
+        ],
+    )
+
+
+def main():
+    """Train the TIGER model on the pre-built Beauty Arrow batches.
+
+    Seeds the RNGs, loads the item-to-semantic-id mapping, opens the train /
+    validation / eval batch directories, builds the model and the AdamW
+    optimizer, wires up logging, periodic evaluation and early stopping, and
+    runs ``TrainingRunner``.
+    """
+    fix_random_seed(SEED_VALUE)
+
+    with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+        mappings = json.load(f)
+
+    # batch_size=1: each ArrowBatchDataset item is presumably an already
+    # assembled batch - TODO confirm against ArrowBatchDataset.
+    train_dataloader = DataLoader(
+        ArrowBatchDataset(
+            TRAIN_BATCHES_DIR,
+            device='cpu',
+            preload=True
+        ),
+        batch_size=1,
+        shuffle=True,
+        num_workers=0,
+        pin_memory=True,
+        collate_fn=Collate()
+    ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)
+
+    # Held-out splits are loaded straight onto the target device.
+    valid_dataset = ArrowBatchDataset(
+        VALID_BATCHES_DIR,
+        device=DEVICE,
+        preload=True
+    )
+
+    eval_dataset = ArrowBatchDataset(
+        EVAL_BATCHES_DIR,
+        device=DEVICE,
+        preload=True
+    )
+
+    model = TigerModel(
+        embedding_dim=EMBEDDING_DIM,
+        codebook_size=CODEBOOK_SIZE,
+        sem_id_len=NUM_CODEBOOKS,
+        user_ids_count=NUM_USER_HASH,
+        num_positions=NUM_POSITIONS,
+        num_heads=NUM_HEADS,
+        num_encoder_layers=NUM_LAYERS,
+        num_decoder_layers=NUM_LAYERS,
+        dim_feedforward=FEEDFORWARD_DIM,
+        num_beams=NUM_BEAMS,
+        num_return_sequences=TOP_K,
+        activation='relu',
+        d_kv=KV_DIM,
+        dropout=DROPOUT,
+        layer_norm_eps=1e-6,
+        initializer_range=0.02,
+        # Factory: the remaining constructor arguments are supplied by the model.
+        logits_processor=partial(
+            CorrectItemsLogitsProcessor,
+            NUM_CODEBOOKS,
+            CODEBOOK_SIZE,
+            mappings,
+            NUM_BEAMS
+        )
+    ).to(DEVICE)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.debug(f'Overall parameters: {total_params:,}')
+    logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+    optimizer = torch.optim.AdamW(
+        model.parameters(),
+        lr=LR,
+    )
+
+    # Fixed logging / evaluation period in optimizer steps (not a true epoch).
+    EPOCH_NUM_STEPS = 1024  # int(len(train_dataloader) // NUM_EPOCHS)
+
+    callbacks = [
+        cb.BatchMetrics(metrics=lambda model_outputs, _: {
+            'loss': model_outputs['loss'].item(),
+        }, name='train'),
+        cb.MetricAccumulator(
+            accumulators={
+                'train/loss': cb.MeanAccumulator(),
+            },
+            reset_every_num_steps=EPOCH_NUM_STEPS
+        ),
+
+        _evaluation_callback(valid_dataset, 'validation').every_num_steps(EPOCH_NUM_STEPS),
+        _evaluation_callback(eval_dataset, 'eval').every_num_steps(EPOCH_NUM_STEPS),
+
+        cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
+        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),
+
+        # NOTE(review): early stopping monitors the 'eval' split, not the
+        # 'validation' split. If EVAL_BATCHES_DIR holds the test set, the
+        # checkpoint is selected on test data - confirm this is intended.
+        cb.EarlyStopping(
+            metric='eval/ndcg@20',
+            patience=40,
+            minimize=False,
+            model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
+        ).every_num_steps(EPOCH_NUM_STEPS)
+
+        # Disabled debugging helpers:
+        # cb.Profiler(
+        #     wait=10,
+        #     warmup=10,
+        #     active=10,
+        #     logdir=TENSORBOARD_LOGDIR
+        # ),
+        # cb.StopAfterNumSteps(40)
+
+    ]
+
+    logger.debug('Everything is ready for training process!')
+
+    runner = TrainingRunner(
+        model=model,
+        optimizer=optimizer,
+        dataset=train_dataloader,
+        callbacks=callbacks,
+    )
+    runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger/beauty_exps/varka_4.1_plum.py b/scripts/tiger/beauty_exps/varka_4.1_plum.py
new file mode 100644
index 0000000..302e04e
--- /dev/null
+++ b/scripts/tiger/beauty_exps/varka_4.1_plum.py
@@ -0,0 +1,287 @@
+from collections import defaultdict
+import json
+import murmurhash
+import numpy as np
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.transforms import Collate, Transform
+from irec.data.dataloader import DataLoader
+
+from data import Dataset
+
+
+
# Paths

IREC_PATH = '../../'
INTERACTIONS_TRAIN_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4_0.9_inter_tiger_train.json')
INTERACTIONS_VALID_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/valid_set.json')
INTERACTIONS_TEST_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/test_set.json')

SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-1_updated_quantile_plum_rqvae_beauty_ws_2_clusters_colisionless.json')
TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-1_train_batches/')
VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-1_valid_batches/')
EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-1_eval_batches/')


# Miscellaneous

SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


MAX_SEQ_LEN = 20
TRAIN_BATCH_SIZE = 256
VALID_BATCH_SIZE = 1024
NUM_USER_HASH = 2000
CODEBOOK_SIZE = 256
NUM_CODEBOOKS = 4

# Vocabulary layout: codebook tokens, then hashed-user tokens, then 10 utility ids.
UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10  # 10 for utilities
# Special tokens take the last ids of the vocabulary. They must be plain ints:
# a trailing comma after the expression would turn each one into a 1-tuple.
PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1
EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2
DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3
+
+
+
class TigerProcessing(Transform):
    """Assemble model inputs and decoder targets from padded semantic ids.

    Reads ``item.semantic.padded``, ``item.semantic.mask``, ``user.hashed.ids``
    and ``labels.semantic.padded`` from the batch and adds ``input.data``,
    ``input.mask`` and ``output.data`` to it.
    """

    def __call__(self, batch):
        item_tokens = batch['item.semantic.padded']
        item_mask = batch['item.semantic.mask']
        num_rows = item_mask.shape[0]

        # In-place write: the array stored under 'item.semantic.padded' changes too.
        item_tokens[~item_mask] = PAD_TOKEN_ID  # TODO ???

        # Hashed user ids are shifted by NUM_CODEBOOKS * CODEBOOK_SIZE
        # (presumably past the item-token id range - confirm against the mapping).
        user_column = NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
        # The appended user column is never masked out.
        user_mask_column = np.ones((num_rows, 1), dtype=item_mask.dtype)

        input_tokens = np.concatenate((item_tokens, user_column), axis=-1)
        input_mask = np.concatenate((item_mask, user_mask_column), axis=-1)

        batch['input.data'] = input_tokens
        batch['input.mask'] = input_mask

        # Targets are the label tokens prefixed with one decoder-start column.
        label_tokens = batch['labels.semantic.padded']
        start_column = np.ones((num_rows, 1), dtype=np.int64) * DECODER_START_TOKEN_ID
        batch['output.data'] = np.concatenate((start_column, label_tokens), axis=-1)

        return batch
+
+
class ToMasked(Transform):
    """Scatter a flat, length-delimited array into a padded array plus a boolean mask.

    For a given ``prefix`` the transform reads ``{prefix}.ids`` and
    ``{prefix}.length`` and writes ``{prefix}.padded`` and ``{prefix}.mask``.
    With ``is_right_aligned=True`` the valid positions sit at the end of each row.
    """

    def __init__(self, prefix, is_right_aligned=False):
        self._prefix = prefix
        self._is_right_aligned = is_right_aligned

    def __call__(self, batch):
        key = self._prefix
        flat_values = batch[f'{key}.ids']
        lengths = batch[f'{key}.length']

        num_rows = lengths.shape[0]
        longest = int(lengths.max())

        # 1-D input: plain indices -> (batch_size, max_seq_len).
        # 2-D input: embeddings    -> (batch_size, max_seq_len, emb_dim).
        assert flat_values.ndim in (1, 2)
        padded = np.zeros(
            (num_rows, longest) + tuple(flat_values.shape[1:]),
            dtype=flat_values.dtype,
        )

        # True where the column index is below the row's length.
        mask = np.arange(longest)[None] < lengths[:, None]
        if self._is_right_aligned:
            # Reverse the columns so the True cells end each row.
            mask = mask[:, ::-1]

        # Boolean assignment fills the True cells in row-major order.
        padded[mask] = flat_values

        batch[f'{key}.padded'] = padded
        batch[f'{key}.mask'] = mask
        return batch
+
+
class SemanticIdsMapper(Transform):
    """Replace item ids with their semantic-id codes.

    Args:
        mapping: dict whose keys are the strings ``"0" .. str(len(mapping) - 1)``
            and whose values are equally long lists of integer codes.
        names: batch prefixes to convert. For every ``name`` with a ``{name}.ids``
            entry the transform adds ``{name}.semantic.ids`` (flattened codes) and
            ``{name}.semantic.length`` (``{name}.length`` times the code length).
            Defaults to no prefixes.

    Raises:
        KeyError: if ``mapping`` has no entry for some id in ``range(len(mapping))``.
    """

    def __init__(self, mapping, names=None):
        super().__init__()
        self._mapping = mapping
        # A None default avoids the shared mutable ``[]`` default; the copy keeps
        # this instance independent of the caller's list.
        self._names = list(names) if names is not None else []

        # Row i of the lookup table holds the codes of item id i.
        rows = [mapping[str(item_id)] for item_id in range(len(mapping))]
        self._mapping_tensor = torch.tensor(rows, dtype=torch.long)
        self._semantic_length = self._mapping_tensor.shape[-1]

    def __call__(self, batch):
        for name in self._names:
            ids_key = f'{name}.ids'
            if ids_key not in batch:
                continue

            ids = batch[ids_key]
            lengths = batch[f'{name}.length']
            # Every id must address an existing row of the lookup table.
            assert ids.min() >= 0
            assert ids.max() < self._mapping_tensor.shape[0]
            batch[f'{name}.semantic.ids'] = self._mapping_tensor[ids].flatten().numpy()
            batch[f'{name}.semantic.length'] = lengths * self._semantic_length

        return batch
+
+
class UserHashing(Transform):
    """Add ``user.hashed.ids``: each user id hashed into one of ``hash_size`` buckets."""

    def __init__(self, hash_size):
        super().__init__()
        self._hash_size = hash_size

    def __call__(self, batch):
        buckets = []
        for user_id in batch['user.ids']:
            # Hash the string form of the id, then reduce modulo the bucket count.
            buckets.append(murmurhash.hash(str(user_id)) % self._hash_size)
        batch['user.hashed.ids'] = np.array(buckets, dtype=np.int64)
        return batch
+
+
def save_batches_to_arrow(batches, output_dir):
    """Write every batch to Feather files, one file per distinct leading length.

    The fields of a batch are grouped by ``len(value)``; each group becomes one
    table stored as ``batch_{idx:06d}_len_{length}.arrow`` (lz4 compressed).
    The string form of each field's shape and dtype is kept in the schema
    metadata under ``{key}_shape`` and ``{key}_dtype``.

    Args:
        batches: iterable of dicts mapping field names to numpy arrays.
        output_dir: directory to create; it must not exist yet.

    Raises:
        FileExistsError: if ``output_dir`` already exists.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=False)

    for batch_idx, batch in enumerate(batches):
        columns_by_length = defaultdict(dict)
        metadata_by_length = defaultdict(dict)

        for key, value in batch.items():
            length = len(value)

            group_metadata = metadata_by_length[length]
            group_metadata[f'{key}_shape'] = str(value.shape)
            group_metadata[f'{key}_dtype'] = str(value.dtype)

            if value.ndim == 1:
                # 1-D array: stored as is.
                column = value
            elif value.ndim == 2:
                # 2-D array: stored as a list of lists.
                column = value.tolist()
            else:
                # >2-D array: flattened; the shape is kept in the metadata.
                # NOTE(review): the flattened column usually has more than `length`
                # elements - confirm that no >2-D field shares a group with others.
                column = value.flatten()
            columns_by_length[length][key] = column

        for length, columns in columns_by_length.items():
            table = pa.table({name: pa.array(column) for name, column in columns.items()})
            # Each length seen above also has a metadata entry.
            table = table.replace_schema_metadata(metadata_by_length[length])

            feather.write_feather(
                table,
                target_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
                compression='lz4',
            )
+
+
def _build_dataloader(dataset, mappings, batch_size, shuffle, drop_last, with_visited):
    """Wrap ``dataset`` in a DataLoader followed by the shared preprocessing chain."""
    loader = (
        DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
        )
        .map(Collate())
        .map(UserHashing(NUM_USER_HASH))
        .map(SemanticIdsMapper(mappings, names=['item', 'labels']))
        .map(ToMasked('item.semantic', is_right_aligned=True))
        .map(ToMasked('labels.semantic', is_right_aligned=True))
    )
    if with_visited:
        # Only the validation and test pipelines pad the 'visited' field.
        loader = loader.map(ToMasked('visited', is_right_aligned=True))
    return loader.map(TigerProcessing())


def _export_batches(dataloader, output_dir):
    """Collect every batch of ``dataloader`` and write them to ``output_dir``."""
    batches = [batch for batch in dataloader]
    save_batches_to_arrow(batches, output_dir)


def main():
    """Precompute the train, validation and test batches and store them as Arrow files."""
    # NOTE(review): SEED_VALUE is defined in this script but never applied, so the
    # shuffled train order is presumably not reproducible - confirm whether seeding
    # happens elsewhere.
    data = Dataset.create_timestamp_based(
        train_json_path=INTERACTIONS_TRAIN_PATH,
        validation_json_path=INTERACTIONS_VALID_PATH,
        test_json_path=INTERACTIONS_TEST_PATH,
        max_sequence_length=MAX_SEQ_LEN,
        sampler_type='tiger',
        min_sample_len=2,
        is_extended=True
    )

    with open(SEMANTIC_MAPPING_PATH, 'r') as f:
        mappings = json.load(f)

    train_dataset, valid_dataset, eval_dataset = data.get_datasets()

    train_dataloader = _build_dataloader(
        train_dataset, mappings, TRAIN_BATCH_SIZE,
        shuffle=True, drop_last=True, with_visited=False,
    )
    valid_dataloader = _build_dataloader(
        valid_dataset, mappings, VALID_BATCH_SIZE,
        shuffle=False, drop_last=False, with_visited=True,
    )
    eval_dataloader = _build_dataloader(
        eval_dataset, mappings, VALID_BATCH_SIZE,
        shuffle=False, drop_last=False, with_visited=True,
    )

    _export_batches(train_dataloader, TRAIN_BATCHES_DIR)
    _export_batches(valid_dataloader, VALID_BATCHES_DIR)
    _export_batches(eval_dataloader, EVAL_BATCHES_DIR)


if __name__ == '__main__':
    main()
diff --git a/scripts/tiger/beauty_exps/varka_4.2_plum.py b/scripts/tiger/beauty_exps/varka_4.2_plum.py
new file mode 100644
index 0000000..b00fef2
--- /dev/null
+++ b/scripts/tiger/beauty_exps/varka_4.2_plum.py
@@ -0,0 +1,288 @@
+from collections import defaultdict
+import json
+import murmurhash
+import numpy as np
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.transforms import Collate, Transform
+from irec.data.dataloader import DataLoader
+
+from data import Dataset
+
+
+
# Paths

IREC_PATH = '../../'
INTERACTIONS_TRAIN_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4_0.9_inter_tiger_train.json')
INTERACTIONS_VALID_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/valid_set.json')
INTERACTIONS_TEST_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/test_set.json')

SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-2_updated_quantile_plum_rqvae_beauty_ws_2_clusters_colisionless.json')
TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-2_train_batches/')
VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-2_valid_batches/')
EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-2_eval_batches/')



# Miscellaneous

SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


MAX_SEQ_LEN = 20
TRAIN_BATCH_SIZE = 256
VALID_BATCH_SIZE = 1024
NUM_USER_HASH = 2000
CODEBOOK_SIZE = 256
NUM_CODEBOOKS = 4

# Vocabulary layout: codebook tokens, then hashed-user tokens, then 10 utility ids.
UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10  # 10 for utilities
# Special tokens take the last ids of the vocabulary. They must be plain ints:
# a trailing comma after the expression would turn each one into a 1-tuple.
PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1
EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2
DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3
+
+
+
class TigerProcessing(Transform):
    """Assemble model inputs and decoder targets from padded semantic ids.

    Reads ``item.semantic.padded``, ``item.semantic.mask``, ``user.hashed.ids``
    and ``labels.semantic.padded`` from the batch and adds ``input.data``,
    ``input.mask`` and ``output.data`` to it.
    """

    def __call__(self, batch):
        item_tokens = batch['item.semantic.padded']
        item_mask = batch['item.semantic.mask']
        num_rows = item_mask.shape[0]

        # In-place write: the array stored under 'item.semantic.padded' changes too.
        item_tokens[~item_mask] = PAD_TOKEN_ID  # TODO ???

        # Hashed user ids are shifted by NUM_CODEBOOKS * CODEBOOK_SIZE
        # (presumably past the item-token id range - confirm against the mapping).
        user_column = NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
        # The appended user column is never masked out.
        user_mask_column = np.ones((num_rows, 1), dtype=item_mask.dtype)

        input_tokens = np.concatenate((item_tokens, user_column), axis=-1)
        input_mask = np.concatenate((item_mask, user_mask_column), axis=-1)

        batch['input.data'] = input_tokens
        batch['input.mask'] = input_mask

        # Targets are the label tokens prefixed with one decoder-start column.
        label_tokens = batch['labels.semantic.padded']
        start_column = np.ones((num_rows, 1), dtype=np.int64) * DECODER_START_TOKEN_ID
        batch['output.data'] = np.concatenate((start_column, label_tokens), axis=-1)

        return batch
+
+
class ToMasked(Transform):
    """Scatter a flat, length-delimited array into a padded array plus a boolean mask.

    For a given ``prefix`` the transform reads ``{prefix}.ids`` and
    ``{prefix}.length`` and writes ``{prefix}.padded`` and ``{prefix}.mask``.
    With ``is_right_aligned=True`` the valid positions sit at the end of each row.
    """

    def __init__(self, prefix, is_right_aligned=False):
        self._prefix = prefix
        self._is_right_aligned = is_right_aligned

    def __call__(self, batch):
        key = self._prefix
        flat_values = batch[f'{key}.ids']
        lengths = batch[f'{key}.length']

        num_rows = lengths.shape[0]
        longest = int(lengths.max())

        # 1-D input: plain indices -> (batch_size, max_seq_len).
        # 2-D input: embeddings    -> (batch_size, max_seq_len, emb_dim).
        assert flat_values.ndim in (1, 2)
        padded = np.zeros(
            (num_rows, longest) + tuple(flat_values.shape[1:]),
            dtype=flat_values.dtype,
        )

        # True where the column index is below the row's length.
        mask = np.arange(longest)[None] < lengths[:, None]
        if self._is_right_aligned:
            # Reverse the columns so the True cells end each row.
            mask = mask[:, ::-1]

        # Boolean assignment fills the True cells in row-major order.
        padded[mask] = flat_values

        batch[f'{key}.padded'] = padded
        batch[f'{key}.mask'] = mask
        return batch
+
+
class SemanticIdsMapper(Transform):
    """Replace item ids with their semantic-id codes.

    Args:
        mapping: dict whose keys are the strings ``"0" .. str(len(mapping) - 1)``
            and whose values are equally long lists of integer codes.
        names: batch prefixes to convert. For every ``name`` with a ``{name}.ids``
            entry the transform adds ``{name}.semantic.ids`` (flattened codes) and
            ``{name}.semantic.length`` (``{name}.length`` times the code length).
            Defaults to no prefixes.

    Raises:
        KeyError: if ``mapping`` has no entry for some id in ``range(len(mapping))``.
    """

    def __init__(self, mapping, names=None):
        super().__init__()
        self._mapping = mapping
        # A None default avoids the shared mutable ``[]`` default; the copy keeps
        # this instance independent of the caller's list.
        self._names = list(names) if names is not None else []

        # Row i of the lookup table holds the codes of item id i.
        rows = [mapping[str(item_id)] for item_id in range(len(mapping))]
        self._mapping_tensor = torch.tensor(rows, dtype=torch.long)
        self._semantic_length = self._mapping_tensor.shape[-1]

    def __call__(self, batch):
        for name in self._names:
            ids_key = f'{name}.ids'
            if ids_key not in batch:
                continue

            ids = batch[ids_key]
            lengths = batch[f'{name}.length']
            # Every id must address an existing row of the lookup table.
            assert ids.min() >= 0
            assert ids.max() < self._mapping_tensor.shape[0]
            batch[f'{name}.semantic.ids'] = self._mapping_tensor[ids].flatten().numpy()
            batch[f'{name}.semantic.length'] = lengths * self._semantic_length

        return batch
+
+
class UserHashing(Transform):
    """Add ``user.hashed.ids``: each user id hashed into one of ``hash_size`` buckets."""

    def __init__(self, hash_size):
        super().__init__()
        self._hash_size = hash_size

    def __call__(self, batch):
        buckets = []
        for user_id in batch['user.ids']:
            # Hash the string form of the id, then reduce modulo the bucket count.
            buckets.append(murmurhash.hash(str(user_id)) % self._hash_size)
        batch['user.hashed.ids'] = np.array(buckets, dtype=np.int64)
        return batch
+
+
def save_batches_to_arrow(batches, output_dir):
    """Write every batch to Feather files, one file per distinct leading length.

    The fields of a batch are grouped by ``len(value)``; each group becomes one
    table stored as ``batch_{idx:06d}_len_{length}.arrow`` (lz4 compressed).
    The string form of each field's shape and dtype is kept in the schema
    metadata under ``{key}_shape`` and ``{key}_dtype``.

    Args:
        batches: iterable of dicts mapping field names to numpy arrays.
        output_dir: directory to create; it must not exist yet.

    Raises:
        FileExistsError: if ``output_dir`` already exists.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=False)

    for batch_idx, batch in enumerate(batches):
        columns_by_length = defaultdict(dict)
        metadata_by_length = defaultdict(dict)

        for key, value in batch.items():
            length = len(value)

            group_metadata = metadata_by_length[length]
            group_metadata[f'{key}_shape'] = str(value.shape)
            group_metadata[f'{key}_dtype'] = str(value.dtype)

            if value.ndim == 1:
                # 1-D array: stored as is.
                column = value
            elif value.ndim == 2:
                # 2-D array: stored as a list of lists.
                column = value.tolist()
            else:
                # >2-D array: flattened; the shape is kept in the metadata.
                # NOTE(review): the flattened column usually has more than `length`
                # elements - confirm that no >2-D field shares a group with others.
                column = value.flatten()
            columns_by_length[length][key] = column

        for length, columns in columns_by_length.items():
            table = pa.table({name: pa.array(column) for name, column in columns.items()})
            # Each length seen above also has a metadata entry.
            table = table.replace_schema_metadata(metadata_by_length[length])

            feather.write_feather(
                table,
                target_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
                compression='lz4',
            )
+
+
def _build_dataloader(dataset, mappings, batch_size, shuffle, drop_last, with_visited):
    """Wrap ``dataset`` in a DataLoader followed by the shared preprocessing chain."""
    loader = (
        DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
        )
        .map(Collate())
        .map(UserHashing(NUM_USER_HASH))
        .map(SemanticIdsMapper(mappings, names=['item', 'labels']))
        .map(ToMasked('item.semantic', is_right_aligned=True))
        .map(ToMasked('labels.semantic', is_right_aligned=True))
    )
    if with_visited:
        # Only the validation and test pipelines pad the 'visited' field.
        loader = loader.map(ToMasked('visited', is_right_aligned=True))
    return loader.map(TigerProcessing())


def _export_batches(dataloader, output_dir):
    """Collect every batch of ``dataloader`` and write them to ``output_dir``."""
    batches = [batch for batch in dataloader]
    save_batches_to_arrow(batches, output_dir)


def main():
    """Precompute the train, validation and test batches and store them as Arrow files."""
    # NOTE(review): SEED_VALUE is defined in this script but never applied, so the
    # shuffled train order is presumably not reproducible - confirm whether seeding
    # happens elsewhere.
    data = Dataset.create_timestamp_based(
        train_json_path=INTERACTIONS_TRAIN_PATH,
        validation_json_path=INTERACTIONS_VALID_PATH,
        test_json_path=INTERACTIONS_TEST_PATH,
        max_sequence_length=MAX_SEQ_LEN,
        sampler_type='tiger',
        min_sample_len=2,
        is_extended=True
    )

    with open(SEMANTIC_MAPPING_PATH, 'r') as f:
        mappings = json.load(f)

    train_dataset, valid_dataset, eval_dataset = data.get_datasets()

    train_dataloader = _build_dataloader(
        train_dataset, mappings, TRAIN_BATCH_SIZE,
        shuffle=True, drop_last=True, with_visited=False,
    )
    valid_dataloader = _build_dataloader(
        valid_dataset, mappings, VALID_BATCH_SIZE,
        shuffle=False, drop_last=False, with_visited=True,
    )
    eval_dataloader = _build_dataloader(
        eval_dataset, mappings, VALID_BATCH_SIZE,
        shuffle=False, drop_last=False, with_visited=True,
    )

    _export_batches(train_dataloader, TRAIN_BATCHES_DIR)
    _export_batches(valid_dataloader, VALID_BATCHES_DIR)
    _export_batches(eval_dataloader, EVAL_BATCHES_DIR)


if __name__ == '__main__':
    main()
diff --git a/scripts/tiger/beauty_exps/varka_4.3_plum.py b/scripts/tiger/beauty_exps/varka_4.3_plum.py
new file mode 100644
index 0000000..2a96339
--- /dev/null
+++ b/scripts/tiger/beauty_exps/varka_4.3_plum.py
@@ -0,0 +1,289 @@
+from collections import defaultdict
+import json
+import murmurhash
+import numpy as np
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.transforms import Collate, Transform
+from irec.data.dataloader import DataLoader
+
+from data import Dataset
+
+
+
# Paths

IREC_PATH = '../../'
INTERACTIONS_TRAIN_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4_0.9_inter_tiger_train.json')
INTERACTIONS_VALID_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/valid_set.json')
INTERACTIONS_TEST_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/updated_quantile_splits/merged_for_exps/test_set.json')

SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-3_updated_quantile_plum_rqvae_beauty_ws_2_clusters_colisionless.json')
TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-3_train_batches/')
VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-3_valid_batches/')
EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/updated_quantile_tiger_4-3_eval_batches/')




# Miscellaneous

SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


MAX_SEQ_LEN = 20
TRAIN_BATCH_SIZE = 256
VALID_BATCH_SIZE = 1024
NUM_USER_HASH = 2000
CODEBOOK_SIZE = 256
NUM_CODEBOOKS = 4

# Vocabulary layout: codebook tokens, then hashed-user tokens, then 10 utility ids.
UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10  # 10 for utilities
# Special tokens take the last ids of the vocabulary. They must be plain ints:
# a trailing comma after the expression would turn each one into a 1-tuple.
PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1
EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2
DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3
+
+
+
class TigerProcessing(Transform):
    """Assemble model inputs and decoder targets from padded semantic ids.

    Reads ``item.semantic.padded``, ``item.semantic.mask``, ``user.hashed.ids``
    and ``labels.semantic.padded`` from the batch and adds ``input.data``,
    ``input.mask`` and ``output.data`` to it.
    """

    def __call__(self, batch):
        item_tokens = batch['item.semantic.padded']
        item_mask = batch['item.semantic.mask']
        num_rows = item_mask.shape[0]

        # In-place write: the array stored under 'item.semantic.padded' changes too.
        item_tokens[~item_mask] = PAD_TOKEN_ID  # TODO ???

        # Hashed user ids are shifted by NUM_CODEBOOKS * CODEBOOK_SIZE
        # (presumably past the item-token id range - confirm against the mapping).
        user_column = NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
        # The appended user column is never masked out.
        user_mask_column = np.ones((num_rows, 1), dtype=item_mask.dtype)

        input_tokens = np.concatenate((item_tokens, user_column), axis=-1)
        input_mask = np.concatenate((item_mask, user_mask_column), axis=-1)

        batch['input.data'] = input_tokens
        batch['input.mask'] = input_mask

        # Targets are the label tokens prefixed with one decoder-start column.
        label_tokens = batch['labels.semantic.padded']
        start_column = np.ones((num_rows, 1), dtype=np.int64) * DECODER_START_TOKEN_ID
        batch['output.data'] = np.concatenate((start_column, label_tokens), axis=-1)

        return batch
+
+
class ToMasked(Transform):
    """Scatter a flat, length-delimited array into a padded array plus a boolean mask.

    For a given ``prefix`` the transform reads ``{prefix}.ids`` and
    ``{prefix}.length`` and writes ``{prefix}.padded`` and ``{prefix}.mask``.
    With ``is_right_aligned=True`` the valid positions sit at the end of each row.
    """

    def __init__(self, prefix, is_right_aligned=False):
        self._prefix = prefix
        self._is_right_aligned = is_right_aligned

    def __call__(self, batch):
        key = self._prefix
        flat_values = batch[f'{key}.ids']
        lengths = batch[f'{key}.length']

        num_rows = lengths.shape[0]
        longest = int(lengths.max())

        # 1-D input: plain indices -> (batch_size, max_seq_len).
        # 2-D input: embeddings    -> (batch_size, max_seq_len, emb_dim).
        assert flat_values.ndim in (1, 2)
        padded = np.zeros(
            (num_rows, longest) + tuple(flat_values.shape[1:]),
            dtype=flat_values.dtype,
        )

        # True where the column index is below the row's length.
        mask = np.arange(longest)[None] < lengths[:, None]
        if self._is_right_aligned:
            # Reverse the columns so the True cells end each row.
            mask = mask[:, ::-1]

        # Boolean assignment fills the True cells in row-major order.
        padded[mask] = flat_values

        batch[f'{key}.padded'] = padded
        batch[f'{key}.mask'] = mask
        return batch
+
+
class SemanticIdsMapper(Transform):
    """Replace item ids with their semantic-id codes.

    Args:
        mapping: dict whose keys are the strings ``"0" .. str(len(mapping) - 1)``
            and whose values are equally long lists of integer codes.
        names: batch prefixes to convert. For every ``name`` with a ``{name}.ids``
            entry the transform adds ``{name}.semantic.ids`` (flattened codes) and
            ``{name}.semantic.length`` (``{name}.length`` times the code length).
            Defaults to no prefixes.

    Raises:
        KeyError: if ``mapping`` has no entry for some id in ``range(len(mapping))``.
    """

    def __init__(self, mapping, names=None):
        super().__init__()
        self._mapping = mapping
        # A None default avoids the shared mutable ``[]`` default; the copy keeps
        # this instance independent of the caller's list.
        self._names = list(names) if names is not None else []

        # Row i of the lookup table holds the codes of item id i.
        rows = [mapping[str(item_id)] for item_id in range(len(mapping))]
        self._mapping_tensor = torch.tensor(rows, dtype=torch.long)
        self._semantic_length = self._mapping_tensor.shape[-1]

    def __call__(self, batch):
        for name in self._names:
            ids_key = f'{name}.ids'
            if ids_key not in batch:
                continue

            ids = batch[ids_key]
            lengths = batch[f'{name}.length']
            # Every id must address an existing row of the lookup table.
            assert ids.min() >= 0
            assert ids.max() < self._mapping_tensor.shape[0]
            batch[f'{name}.semantic.ids'] = self._mapping_tensor[ids].flatten().numpy()
            batch[f'{name}.semantic.length'] = lengths * self._semantic_length

        return batch
+
+
class UserHashing(Transform):
    """Add ``user.hashed.ids``: each user id hashed into one of ``hash_size`` buckets."""

    def __init__(self, hash_size):
        super().__init__()
        self._hash_size = hash_size

    def __call__(self, batch):
        buckets = []
        for user_id in batch['user.ids']:
            # Hash the string form of the id, then reduce modulo the bucket count.
            buckets.append(murmurhash.hash(str(user_id)) % self._hash_size)
        batch['user.hashed.ids'] = np.array(buckets, dtype=np.int64)
        return batch
+
+
def save_batches_to_arrow(batches, output_dir):
    """Write every batch to Feather files, one file per distinct leading length.

    The fields of a batch are grouped by ``len(value)``; each group becomes one
    table stored as ``batch_{idx:06d}_len_{length}.arrow`` (lz4 compressed).
    The string form of each field's shape and dtype is kept in the schema
    metadata under ``{key}_shape`` and ``{key}_dtype``.

    Args:
        batches: iterable of dicts mapping field names to numpy arrays.
        output_dir: directory to create; it must not exist yet.

    Raises:
        FileExistsError: if ``output_dir`` already exists.
    """
    target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=False)

    for batch_idx, batch in enumerate(batches):
        columns_by_length = defaultdict(dict)
        metadata_by_length = defaultdict(dict)

        for key, value in batch.items():
            length = len(value)

            group_metadata = metadata_by_length[length]
            group_metadata[f'{key}_shape'] = str(value.shape)
            group_metadata[f'{key}_dtype'] = str(value.dtype)

            if value.ndim == 1:
                # 1-D array: stored as is.
                column = value
            elif value.ndim == 2:
                # 2-D array: stored as a list of lists.
                column = value.tolist()
            else:
                # >2-D array: flattened; the shape is kept in the metadata.
                # NOTE(review): the flattened column usually has more than `length`
                # elements - confirm that no >2-D field shares a group with others.
                column = value.flatten()
            columns_by_length[length][key] = column

        for length, columns in columns_by_length.items():
            table = pa.table({name: pa.array(column) for name, column in columns.items()})
            # Each length seen above also has a metadata entry.
            table = table.replace_schema_metadata(metadata_by_length[length])

            feather.write_feather(
                table,
                target_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
                compression='lz4',
            )
+
+
def _build_dataloader(dataset, mappings, batch_size, shuffle, drop_last, with_visited):
    """Wrap ``dataset`` in a DataLoader followed by the shared preprocessing chain."""
    loader = (
        DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last,
        )
        .map(Collate())
        .map(UserHashing(NUM_USER_HASH))
        .map(SemanticIdsMapper(mappings, names=['item', 'labels']))
        .map(ToMasked('item.semantic', is_right_aligned=True))
        .map(ToMasked('labels.semantic', is_right_aligned=True))
    )
    if with_visited:
        # Only the validation and test pipelines pad the 'visited' field.
        loader = loader.map(ToMasked('visited', is_right_aligned=True))
    return loader.map(TigerProcessing())


def _export_batches(dataloader, output_dir):
    """Collect every batch of ``dataloader`` and write them to ``output_dir``."""
    batches = [batch for batch in dataloader]
    save_batches_to_arrow(batches, output_dir)


def main():
    """Precompute the train, validation and test batches and store them as Arrow files."""
    # NOTE(review): SEED_VALUE is defined in this script but never applied, so the
    # shuffled train order is presumably not reproducible - confirm whether seeding
    # happens elsewhere.
    data = Dataset.create_timestamp_based(
        train_json_path=INTERACTIONS_TRAIN_PATH,
        validation_json_path=INTERACTIONS_VALID_PATH,
        test_json_path=INTERACTIONS_TEST_PATH,
        max_sequence_length=MAX_SEQ_LEN,
        sampler_type='tiger',
        min_sample_len=2,
        is_extended=True
    )

    with open(SEMANTIC_MAPPING_PATH, 'r') as f:
        mappings = json.load(f)

    train_dataset, valid_dataset, eval_dataset = data.get_datasets()

    train_dataloader = _build_dataloader(
        train_dataset, mappings, TRAIN_BATCH_SIZE,
        shuffle=True, drop_last=True, with_visited=False,
    )
    valid_dataloader = _build_dataloader(
        valid_dataset, mappings, VALID_BATCH_SIZE,
        shuffle=False, drop_last=False, with_visited=True,
    )
    eval_dataloader = _build_dataloader(
        eval_dataset, mappings, VALID_BATCH_SIZE,
        shuffle=False, drop_last=False, with_visited=True,
    )

    _export_batches(train_dataloader, TRAIN_BATCHES_DIR)
    _export_batches(valid_dataloader, VALID_BATCHES_DIR)
    _export_batches(eval_dataloader, EVAL_BATCHES_DIR)


if __name__ == '__main__':
    main()
diff --git a/scripts/tiger/data.py b/scripts/tiger/data.py
index 188993a..a34accd 100644
--- a/scripts/tiger/data.py
+++ b/scripts/tiger/data.py
@@ -28,6 +28,116 @@ def __init__(
self._num_items = num_items
self._max_sequence_length = max_sequence_length
+ @classmethod
+ def create_timestamp_based(
+ cls,
+ train_json_path,
+ validation_json_path,
+ test_json_path,
+ max_sequence_length,
+ sampler_type,
+ min_sample_len=2,
+ is_extended=False
+ ):
+ max_item_id = 0
+ train_dataset, validation_dataset, test_dataset = [], [], []
+
+ with open(train_json_path, 'r') as f:
+ train_data = json.load(f)
+ with open(validation_json_path, 'r') as f:
+ validation_data = json.load(f)
+ with open(test_json_path, 'r') as f:
+ test_data = json.load(f)
+
+ all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys())
+ print(f"all users count: {len(all_users)}")
+ for user_id_str in all_users:
+ user_id = int(user_id_str)
+
+ train_items = train_data.get(user_id_str, [])
+ validation_items = validation_data.get(user_id_str, [])
+ test_items = test_data.get(user_id_str, [])
+
+ full_sequence = train_items + validation_items + test_items
+ if full_sequence:
+ max_item_id = max(max_item_id, max(full_sequence))
+
+ assert len(full_sequence) >= 5, f'Core-5 dataset is used, user {user_id} has only {len(full_sequence)} items'
+
+ if is_extended:
+ # sample = [1, 2]
+ # sample = [1, 2, 3]
+ # sample = [1, 2, 3, 4]
+ # sample = [1, 2, 3, 4, 5]
+ # sample = [1, 2, 3, 4, 5, 6]
+ # sample = [1, 2, 3, 4, 5, 6, 7]
+ # sample = [1, 2, 3, 4, 5, 6, 7, 8]
+ for prefix_length in range(min_sample_len, len(train_items) + 1):
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': train_items[:prefix_length],
+ })
+ else:
+ # sample = [1, 2, 3, 4, 5, 6, 7, 8]
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': train_items,
+ })
+
+ # разворачиваем каждый айтем из валидации в отдельный сэмпл
+ # Пример: Train=[1,2], Valid=[3,4]
+ # sample = [1, 2, 3]
+ # sample = [1, 2, 3, 4]
+
+ current_history = train_items.copy()
+ for item in validation_items:
+ # эвал датасет сам отрезает таргет потом
+ sample_sequence = current_history + [item]
+
+ if len(sample_sequence) >= min_sample_len:
+ validation_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': sample_sequence,
+ })
+ current_history.append(item)
+
+ # разворачиваем каждый айтем из теста в отдельный сэмпл
+ # Пример: Train=[1,2], Valid=[3,4], Test=[5, 6]
+ # sample = [1, 2, 3, 4, 5]
+ # sample = [1, 2, 3, 4, 5, 6]
+ current_history = train_items + validation_items
+
+ for item in test_items:
+ # эвал датасет сам отрезает таргет потом
+ sample_sequence = current_history + [item]
+
+ if len(sample_sequence) >= min_sample_len:
+ test_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': sample_sequence,
+ })
+
+ current_history.append(item)
+
+ logger.debug(f'Train dataset size: {len(train_dataset)}')
+ logger.debug(f'Validation dataset size: {len(validation_dataset)}')
+ logger.debug(f'Test dataset size: {len(test_dataset)}')
+ print(f'Train dataset size: {len(train_dataset)}')
+ print(f'Validation dataset size: {len(validation_dataset)}')
+ print(f'Test dataset size: {len(test_dataset)}')
+
+ train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length)
+ validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length)
+ test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length)
+
+ return cls(
+ train_sampler=train_sampler,
+ validation_sampler=validation_sampler,
+ test_sampler=test_sampler,
+ num_items=max_item_id + 1, # +1 added because our ids are 0-indexed
+ max_sequence_length=max_sequence_length
+ )
+
@classmethod
def create(cls, inter_json_path, max_sequence_length, sampler_type, is_extended=False):
max_item_id = 0
diff --git a/scripts/tiger/train.py b/scripts/tiger/train.py
index f436dd4..1a2d347 100644
--- a/scripts/tiger/train.py
+++ b/scripts/tiger/train.py
@@ -14,10 +14,23 @@
from data import ArrowBatchDataset
from models import TigerModel, CorrectItemsLogitsProcessor
+
+# ПУТИ
+IREC_PATH = '../../'
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-1_plum_rqvae_beauty_ws_2_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_eval_batches/')
+
+TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs')
+CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints')
+
+EXPERIMENT_NAME = 'tiger_beauty_4-1_plum_ws_2_dp_0.2'
+
+# ОСТАЛЬНОЕ
SEED_VALUE = 42
DEVICE = 'cuda'
-EXPERIMENT_NAME = 'tiger_beauty'
NUM_EPOCHS = 300
MAX_SEQ_LEN = 20
TRAIN_BATCH_SIZE = 256
@@ -30,13 +43,12 @@
NUM_LAYERS = 4
FEEDFORWARD_DIM = 1024
KV_DIM = 64
-DROPOUT = 0.1
+DROPOUT = 0.2
NUM_BEAMS = 30
TOP_K = 20
NUM_CODEBOOKS = 4
-LR = 3e-4
+LR = 0.0001
-IREC_PATH = '../../'
torch.set_float32_matmul_precision('high')
torch._dynamo.config.capture_scalar_outputs = True
@@ -48,30 +60,30 @@
def main():
fix_random_seed(SEED_VALUE)
- with open(os.path.join(IREC_PATH, 'results/rqvae_beauty_best_clusters_colisionless.json'), 'r') as f:
+ with open(SEMANTIC_MAPPING_PATH, 'r') as f:
mappings = json.load(f)
-
+
train_dataloader = DataLoader(
ArrowBatchDataset(
- os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/'),
- device='cpu',
+ TRAIN_BATCHES_DIR,
+ device='cpu',
preload=True
),
- batch_size=1,
- shuffle=True,
+ batch_size=1,
+ shuffle=True,
num_workers=0,
- pin_memory=True,
+ pin_memory=True,
collate_fn=Collate()
).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)
valid_dataloder = ArrowBatchDataset(
- os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/'),
+ VALID_BATCHES_DIR,
device=DEVICE,
preload=True
)
eval_dataloder = ArrowBatchDataset(
- os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/'),
+ EVAL_BATCHES_DIR,
device=DEVICE,
preload=True
)
@@ -177,22 +189,22 @@ def main():
),
],
).every_num_steps(EPOCH_NUM_STEPS),
-
+
cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
- cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),
cb.EarlyStopping(
- metric='eval/ndcg@20',
+ metric='eval/ndcg@20',
patience=40,
minimize=False,
- model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
+ model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
).every_num_steps(EPOCH_NUM_STEPS)
# cb.Profiler(
# wait=10,
# warmup=10,
# active=10,
- # logdir=os.path.join(IREC_PATH, 'tensorboard_logs')
+ # logdir=TENSORBOARD_LOGDIR
# ),
# cb.StopAfterNumSteps(40)
diff --git a/scripts/tiger/varka.py b/scripts/tiger/varka.py
index ed47595..4dc3e02 100644
--- a/scripts/tiger/varka.py
+++ b/scripts/tiger/varka.py
@@ -15,6 +15,20 @@
from data import Dataset
+
+
+# ПУТИ
+
+IREC_PATH = '../../'
+INTERACTIONS_PATH = os.path.join(IREC_PATH, 'data/Beauty/inter.json')
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/rqvae_beauty_best_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/')
+
+
+# ОСТАЛЬНОЕ
+
SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
@@ -32,8 +46,6 @@
DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3,
-IREC_PATH = '../../'
-
class TigerProcessing(Transform):
def __call__(self, batch):
@@ -42,12 +54,12 @@ def __call__(self, batch):
input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ???
- input_semantic_ids = np.concat([
+ input_semantic_ids = np.concatenate([
input_semantic_ids,
NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
], axis=-1)
- attention_mask = np.concat([
+ attention_mask = np.concatenate([
attention_mask,
np.ones((batch_size, 1), dtype=attention_mask.dtype)
], axis=-1)
@@ -56,7 +68,7 @@ def __call__(self, batch):
batch['input.mask'] = attention_mask
target_semantic_ids = batch['labels.semantic.padded']
- target_semantic_ids = np.concat([
+ target_semantic_ids = np.concatenate([
np.ones(
(batch_size, 1),
dtype=np.int64,
@@ -73,7 +85,7 @@ class ToMasked(Transform):
def __init__(self, prefix, is_right_aligned=False):
self._prefix = prefix
self._is_right_aligned = is_right_aligned
-
+
def __call__(self, batch):
data = batch[f'{self._prefix}.ids']
lengths = batch[f'{self._prefix}.length']
@@ -92,7 +104,7 @@ def __call__(self, batch):
(batch_size, max_sequence_length, data.shape[-1]),
dtype=data.dtype
) # (batch_size, max_seq_len, emb_dim)
-
+
mask = np.arange(max_sequence_length)[None] < lengths[:, None]
if self._is_right_aligned:
@@ -117,10 +129,10 @@ def __init__(self, mapping, names=[]):
data.append(mapping[str(i)])
self._mapping_tensor = torch.tensor(data, dtype=torch.long)
self._semantic_length = self._mapping_tensor.shape[-1]
-
+
def __call__(self, batch):
for name in self._names:
- if f'{name}.ids' in batch:
+ if f'{name}.ids' in batch:
ids = batch[f'{name}.ids']
lengths = batch[f'{name}.length']
assert ids.min() >= 0
@@ -135,7 +147,7 @@ class UserHashing(Transform):
def __init__(self, hash_size):
super().__init__()
self._hash_size = hash_size
-
+
def __call__(self, batch):
batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64)
return batch
@@ -144,7 +156,7 @@ def __call__(self, batch):
def save_batches_to_arrow(batches, output_dir):
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=False)
-
+
for batch_idx, batch in enumerate(batches):
length_groups = defaultdict(dict)
metadata_groups = defaultdict(dict)
@@ -164,7 +176,7 @@ def save_batches_to_arrow(batches, output_dir):
else:
# >2D массив - flatten и сохраняем shape
length_groups[length][key] = value.flatten()
-
+
for length, fields in length_groups.items():
arrow_dict = {}
for k, v in fields.items():
@@ -173,11 +185,11 @@ def save_batches_to_arrow(batches, output_dir):
arrow_dict[k] = pa.array(v)
else:
arrow_dict[k] = pa.array(v)
-
+
table = pa.table(arrow_dict)
if length in metadata_groups:
table = table.replace_schema_metadata(metadata_groups[length])
-
+
feather.write_feather(
table,
output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
@@ -186,7 +198,7 @@ def save_batches_to_arrow(batches, output_dir):
# arrow_dict = {k: pa.array(v) for k, v in fields.items()}
# table = pa.table(arrow_dict)
-
+
# feather.write_feather(
# table,
# output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
@@ -196,15 +208,15 @@ def save_batches_to_arrow(batches, output_dir):
def main():
data = Dataset.create(
- inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter.json'),
+ inter_json_path=INTERACTIONS_PATH,
max_sequence_length=MAX_SEQ_LEN,
sampler_type='tiger',
is_extended=True
)
- with open(os.path.join(IREC_PATH, 'results/rqvae_beauty_best_clusters_colisionless.json'), 'r') as f:
+ with open(SEMANTIC_MAPPING_PATH, 'r') as f:
mappings = json.load(f)
-
+
train_dataset, valid_dataset, eval_dataset = data.get_datasets()
train_dataloader = DataLoader(
@@ -219,7 +231,7 @@ def main():
.map(ToMasked('item.semantic', is_right_aligned=True)) \
.map(ToMasked('labels.semantic', is_right_aligned=True)) \
.map(TigerProcessing())
-
+
valid_dataloader = DataLoader(
dataset=valid_dataset,
batch_size=VALID_BATCH_SIZE,
@@ -251,17 +263,18 @@ def main():
train_batches = []
for train_batch in train_dataloader:
train_batches.append(train_batch)
- save_batches_to_arrow(train_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/'))
-
+ save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR)
+
valid_batches = []
for valid_batch in valid_dataloader:
valid_batches.append(valid_batch)
- save_batches_to_arrow(valid_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/'))
-
+ save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR)
+
eval_batches = []
for eval_batch in eval_dataloader:
eval_batches.append(eval_batch)
- save_batches_to_arrow(eval_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/'))
+ save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR)
+
if __name__ == '__main__':
diff --git a/scripts/tiger/varka_timestamp_based.py b/scripts/tiger/varka_timestamp_based.py
new file mode 100644
index 0000000..11343ea
--- /dev/null
+++ b/scripts/tiger/varka_timestamp_based.py
@@ -0,0 +1,287 @@
+from collections import defaultdict
+import json
+import murmurhash
+import numpy as np
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.transforms import Collate, Transform
+from irec.data.dataloader import DataLoader
+
+from data import Dataset
+
+
+
+# PATHS
+
+IREC_PATH = '../../'  # relative path, so it is resolved against the current working directory
+INTERACTIONS_TRAIN_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/splits/exp_data/exp_4_inter_tiger_train.json')
+INTERACTIONS_VALID_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/splits/exp_data/valid_skip_set.json')
+INTERACTIONS_TEST_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/splits/exp_data/test_set.json')
+
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-1_plum_rqvae_beauty_ws_2_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_eval_batches/')
+
+
+# EVERYTHING ELSE
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+# Batching and vocabulary layout
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+NUM_USER_HASH = 2000
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 4
+# Special tokens take the last ids of the vocabulary; plain ints (a trailing comma would make them 1-tuples).
+UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities
+PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1
+EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2
+DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3
+
+
+
+class TigerProcessing(Transform):
+    """Assemble the model-facing arrays of a batch.
+
+    Adds 'input.data', 'input.mask' and 'output.data' and returns the same batch dict.
+    """
+
+    def __call__(self, batch):
+        item_tokens = batch['item.semantic.padded']
+        item_mask = batch['item.semantic.mask']
+        num_rows = item_mask.shape[0]
+
+        # In-place write: masked-out slots of the stored padded array become PAD_TOKEN_ID.
+        item_tokens[~item_mask] = PAD_TOKEN_ID  # TODO ???
+
+        # One extra column per row: the hashed user id shifted by NUM_CODEBOOKS * CODEBOOK_SIZE.
+        user_column = NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
+        user_mask = np.ones((num_rows, 1), dtype=item_mask.dtype)
+
+        batch['input.data'] = np.concatenate([item_tokens, user_column], axis=-1)
+        batch['input.mask'] = np.concatenate([item_mask, user_mask], axis=-1)
+
+        # Every target row is prefixed with DECODER_START_TOKEN_ID.
+        start_column = np.ones(
+            (num_rows, 1),
+            dtype=np.int64,
+        ) * DECODER_START_TOKEN_ID
+        batch['output.data'] = np.concatenate(
+            [start_column, batch['labels.semantic.padded']],
+            axis=-1,
+        )
+
+        return batch
+
+
+class ToMasked(Transform):
+    """Scatter flat '<prefix>.ids' values into '<prefix>.padded' plus a boolean '<prefix>.mask'.
+
+    With ``is_right_aligned`` the mask is flipped along its last axis (padding ends up on the left).
+    """
+
+    def __init__(self, prefix, is_right_aligned=False):
+        # NOTE(review): the sibling transforms call super().__init__(); this one does not - confirm.
+        self._prefix = prefix
+        self._is_right_aligned = is_right_aligned
+
+    def __call__(self, batch):
+        flat_values = batch[f'{self._prefix}.ids']
+        row_lengths = batch[f'{self._prefix}.length']
+
+        num_rows = row_lengths.shape[0]
+        width = int(row_lengths.max())
+
+        if flat_values.ndim == 1:  # only indices -> (batch_size, max_seq_len)
+            out_shape = (num_rows, width)
+        else:
+            assert flat_values.ndim == 2  # embeddings -> (batch_size, max_seq_len, emb_dim)
+            out_shape = (num_rows, width, flat_values.shape[-1])
+        padded = np.zeros(out_shape, dtype=flat_values.dtype)
+
+        valid = np.arange(width)[None] < row_lengths[:, None]
+        if self._is_right_aligned:
+            valid = np.flip(valid, axis=-1)
+
+        # Boolean-mask assignment fills the selected slots in row-major order.
+        padded[valid] = flat_values
+        batch[f'{self._prefix}.padded'] = padded
+        batch[f'{self._prefix}.mask'] = valid
+
+        return batch
+
+
+class SemanticIdsMapper(Transform):
+    """Map '<name>.ids' to flattened semantic codes; ``mapping`` is keyed by str(0) .. str(len(mapping) - 1)."""
+
+    def __init__(self, mapping, names=()):  # immutable default (was a mutable list shared between instances)
+        super().__init__()
+        self._mapping = mapping
+        self._names = names
+
+        rows = [mapping[str(i)] for i in range(len(mapping))]  # row i holds the codes of item i
+        self._mapping_tensor = torch.tensor(rows, dtype=torch.long)
+        self._semantic_length = self._mapping_tensor.shape[-1]  # number of codes per item
+
+    def __call__(self, batch):
+        for name in self._names:
+            if f'{name}.ids' in batch:  # names without an '.ids' entry are skipped silently
+                ids = batch[f'{name}.ids']
+                lengths = batch[f'{name}.length']
+                assert ids.min() >= 0
+                assert ids.max() < self._mapping_tensor.shape[0]
+                batch[f'{name}.semantic.ids'] = self._mapping_tensor[ids].flatten().numpy()  # codes of all items, concatenated
+                batch[f'{name}.semantic.length'] = lengths * self._semantic_length  # lengths counted in codes
+
+        return batch
+
+
+class UserHashing(Transform):
+    def __init__(self, hash_size):
+        super().__init__()
+        self._hash_size = hash_size  # modulus applied to every hashed user id
+    # Stores murmurhash.hash(str(user id)) % hash_size for each entry of 'user.ids' as an int64 array.
+    def __call__(self, batch):
+        batch['user.hashed.ids'] = np.array([murmurhash.hash(str(uid)) % self._hash_size for uid in batch['user.ids']], dtype=np.int64)
+        return batch
+
+
+def save_batches_to_arrow(batches, output_dir):
+    """Write every batch to Feather files under ``output_dir``.
+
+    The fields of a batch are grouped by ``len(value)``; each group becomes one
+    table saved as ``batch_{batch_idx:06d}_len_{length}.arrow`` (lz4-compressed).
+    The shape and dtype of every field are stored as strings in the schema
+    metadata under ``<key>_shape`` / ``<key>_dtype``.
+
+    Args:
+        batches: iterable of dicts mapping a field name to an array exposing
+            ``shape``, ``dtype``, ``ndim``, ``tolist`` and ``flatten``
+            (presumably numpy arrays - confirm against the dataloader).
+        output_dir: target directory; created together with its parents.
+
+    Raises:
+        FileExistsError: if ``output_dir`` already exists (``exist_ok=False``).
+    """
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=False)
+
+    for batch_idx, batch in enumerate(batches):
+        length_groups = defaultdict(dict)
+        metadata_groups = defaultdict(dict)
+
+        for key, value in batch.items():
+            length = len(value)
+            metadata_groups[length][f'{key}_shape'] = str(value.shape)
+            metadata_groups[length][f'{key}_dtype'] = str(value.dtype)
+
+            if value.ndim == 1:
+                # 1D array - stored as is
+                length_groups[length][key] = value
+            elif value.ndim == 2:
+                # 2D array - stored as a list of lists
+                length_groups[length][key] = value.tolist()
+            else:
+                # >2D array - flattened, the shape is kept in the metadata
+                # NOTE(review): the flattened column can be longer than `length` - confirm pa.table accepts it.
+                length_groups[length][key] = value.flatten()
+
+        for length, fields in length_groups.items():
+            # pa.array is applied to every field alike (the old if/else had two identical arms).
+            table = pa.table({k: pa.array(v) for k, v in fields.items()})
+            if length in metadata_groups:
+                table = table.replace_schema_metadata(metadata_groups[length])
+
+            feather.write_feather(
+                table,
+                output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+                compression='lz4'
+            )
+
+
+def _build_loader(dataset, mappings, batch_size, shuffle, drop_last, with_visited):
+    """Wrap ``dataset`` in a DataLoader followed by the TIGER preprocessing transforms.
+
+    ``with_visited`` additionally pads the 'visited' field (used for the
+    validation and evaluation loaders only).
+    """
+    loader = DataLoader(
+        dataset=dataset,
+        batch_size=batch_size,
+        shuffle=shuffle,
+        drop_last=drop_last
+    ) \
+        .map(Collate()) \
+        .map(UserHashing(NUM_USER_HASH)) \
+        .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+        .map(ToMasked('item.semantic', is_right_aligned=True)) \
+        .map(ToMasked('labels.semantic', is_right_aligned=True))
+    if with_visited:
+        loader = loader.map(ToMasked('visited', is_right_aligned=True))
+    return loader.map(TigerProcessing())
+
+
+def _export_batches(loader, output_dir):
+    """Collect every batch produced by ``loader`` and write them to ``output_dir``."""
+    batches = []
+    for batch in loader:
+        batches.append(batch)
+    save_batches_to_arrow(batches, output_dir)
+
+
+def main():
+    """Build the train / validation / evaluation loaders and export their batches.
+
+    The three pipelines differ only in batch size, shuffling, ``drop_last`` and
+    the extra 'visited' padding, so they share ``_build_loader``. All loaders
+    are created first; the exports then run in train, validation, evaluation order.
+    """
+    data = Dataset.create_timestamp_based(
+        train_json_path=INTERACTIONS_TRAIN_PATH,
+        validation_json_path=INTERACTIONS_VALID_PATH,
+        test_json_path=INTERACTIONS_TEST_PATH,
+        max_sequence_length=MAX_SEQ_LEN,
+        sampler_type='tiger',
+        min_sample_len=2,
+        is_extended=True
+    )
+
+    with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+        mappings = json.load(f)
+
+    train_dataset, valid_dataset, eval_dataset = data.get_datasets()
+
+    # NOTE(review): shuffle=True is used but SEED_VALUE is never applied in this script - confirm.
+    train_dataloader = _build_loader(
+        train_dataset, mappings, TRAIN_BATCH_SIZE,
+        shuffle=True, drop_last=True, with_visited=False
+    )
+    valid_dataloader = _build_loader(
+        valid_dataset, mappings, VALID_BATCH_SIZE,
+        shuffle=False, drop_last=False, with_visited=True
+    )
+    eval_dataloader = _build_loader(
+        eval_dataset, mappings, VALID_BATCH_SIZE,
+        shuffle=False, drop_last=False, with_visited=True
+    )
+
+    # save_batches_to_arrow creates each directory with exist_ok=False,
+    # so an export fails if its target directory already exists.
+    _export_batches(train_dataloader, TRAIN_BATCHES_DIR)
+    _export_batches(valid_dataloader, VALID_BATCHES_DIR)
+    _export_batches(eval_dataloader, EVAL_BATCHES_DIR)
+
+
+
+if __name__ == '__main__':
+ main()
diff --git a/sigir/Beauty/DatasetProcessing.ipynb b/sigir/Beauty/DatasetProcessing.ipynb
new file mode 100644
index 0000000..b49f4ab
--- /dev/null
+++ b/sigir/Beauty/DatasetProcessing.ipynb
@@ -0,0 +1,856 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "3bdb292f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from collections import defaultdict\n",
+ "\n",
+ "import json\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import pickle\n",
+ "import polars as pl\n",
+ "\n",
+ "from transformers import LlamaModel, LlamaTokenizer\n",
+ "\n",
+ "import torch\n",
+ "from torch.utils.data import DataLoader\n",
+ "\n",
+ "from tqdm import tqdm as tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "66d9b312",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "interactions_dataset_path = '../data/Beauty/Beauty_5.json'\n",
+ "metadata_path = '../data/Beauty/metadata.json'\n",
+ "\n",
+ "interactions_output_json_path = '../data/Beauty_new/inter_new.json'\n",
+ "interactions_output_parquet_path = '../data/Beauty_new/inter_new.parquet'\n",
+ "embeddings_output_path = '../data/Beauty_new/content_embeddings.pkl'\n",
+ "item_ids_mapping_output_path = '../data/Beauty_new/item_ids_mapping.json'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "6ed4dffb",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of events: 198502\n"
+ ]
+ }
+ ],
+ "source": [
+ "df = defaultdict(list)\n",
+ "\n",
+ "with open(interactions_dataset_path, 'r') as f:\n",
+ " for line in f.readlines():\n",
+ " review = json.loads(line)\n",
+ " df['user_id'].append(review['reviewerID'])\n",
+ " df['item_id'].append(review['asin'])\n",
+ " df['timestamp'].append(review['unixReviewTime'])\n",
+ "\n",
+ "print(f'Number of events: {len(df[\"user_id\"])}')\n",
+ "\n",
+ "df = pl.from_dict(df)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "c26746c4",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (5, 3)| user_id | item_id | timestamp |
|---|
| str | str | i64 |
| "A1YJEY40YUW4SE" | "7806397051" | 1391040000 |
| "A60XNB876KYML" | "7806397051" | 1397779200 |
| "A3G6XNM240RMWA" | "7806397051" | 1378425600 |
| "A1PQFP6SAJ6D80" | "7806397051" | 1386460800 |
| "A38FVHZTNQ271F" | "7806397051" | 1382140800 |
"
+ ],
+ "text/plain": [
+ "shape: (5, 3)\n",
+ "┌────────────────┬────────────┬────────────┐\n",
+ "│ user_id ┆ item_id ┆ timestamp │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ str ┆ str ┆ i64 │\n",
+ "╞════════════════╪════════════╪════════════╡\n",
+ "│ A1YJEY40YUW4SE ┆ 7806397051 ┆ 1391040000 │\n",
+ "│ A60XNB876KYML ┆ 7806397051 ┆ 1397779200 │\n",
+ "│ A3G6XNM240RMWA ┆ 7806397051 ┆ 1378425600 │\n",
+ "│ A1PQFP6SAJ6D80 ┆ 7806397051 ┆ 1386460800 │\n",
+ "│ A38FVHZTNQ271F ┆ 7806397051 ┆ 1382140800 │\n",
+ "└────────────────┴────────────┴────────────┘"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "adcf5713",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "filtered_df = df.clone()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c0bbf9ba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Processing dataset to get core-5 state in case full dataset is provided\n",
+ "is_changed = True\n",
+ "threshold = 5\n",
+ "good_users = set()\n",
+ "good_items = set()\n",
+ "\n",
+ "while is_changed:\n",
+ " user_counts = filtered_df.group_by('user_id').agg(\n",
+ " pl.len().alias('user_count'),\n",
+ " )\n",
+ " item_counts = filtered_df.group_by('item_id').agg(\n",
+ " pl.len().alias('item_count'),\n",
+ " )\n",
+ "\n",
+ " good_users = user_counts.filter(pl.col('user_count') >= threshold).select(\n",
+ " 'user_id',\n",
+ " )\n",
+ " good_items = item_counts.filter(pl.col('item_count') >= threshold).select(\n",
+ " 'item_id',\n",
+ " )\n",
+ "\n",
+ " old_size = len(filtered_df)\n",
+ "\n",
+ " new_df = filtered_df.join(good_users, on='user_id', how='inner')\n",
+ " new_df = new_df.join(good_items, on='item_id', how='inner')\n",
+ "\n",
+ " new_size = len(new_df)\n",
+ "\n",
+ " filtered_df = new_df\n",
+ " is_changed = old_size != new_size\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "218a9348",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (5, 3)| user_id | item_id | timestamp |
|---|
| i64 | i64 | i64 |
| 0 | 0 | 1391040000 |
| 1 | 0 | 1397779200 |
| 2 | 0 | 1378425600 |
| 3 | 0 | 1386460800 |
| 4 | 0 | 1382140800 |
"
+ ],
+ "text/plain": [
+ "shape: (5, 3)\n",
+ "┌─────────┬─────────┬────────────┐\n",
+ "│ user_id ┆ item_id ┆ timestamp │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ i64 ┆ i64 ┆ i64 │\n",
+ "╞═════════╪═════════╪════════════╡\n",
+ "│ 0 ┆ 0 ┆ 1391040000 │\n",
+ "│ 1 ┆ 0 ┆ 1397779200 │\n",
+ "│ 2 ┆ 0 ┆ 1378425600 │\n",
+ "│ 3 ┆ 0 ┆ 1386460800 │\n",
+ "│ 4 ┆ 0 ┆ 1382140800 │\n",
+ "└─────────┴─────────┴────────────┘"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "unique_values = filtered_df[\"user_id\"].unique(maintain_order=True).to_list()\n",
+ "user_ids_mapping = {value: i for i, value in enumerate(unique_values)}\n",
+ "\n",
+ "filtered_df = filtered_df.with_columns(\n",
+ " pl.col(\"user_id\").replace_strict(user_ids_mapping)\n",
+ ")\n",
+ "\n",
+ "unique_values = filtered_df[\"item_id\"].unique(maintain_order=True).to_list()\n",
+ "item_ids_mapping = {value: i for i, value in enumerate(unique_values)}\n",
+ "\n",
+ "filtered_df = filtered_df.with_columns(\n",
+ " pl.col(\"item_id\").replace_strict(item_ids_mapping)\n",
+ ")\n",
+ "\n",
+ "filtered_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "34604fe6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (5, 2)| old_item_id | new_item_id |
|---|
| str | i64 |
| "7806397051" | 0 |
| "9759091062" | 1 |
| "9788072216" | 2 |
| "9790790961" | 3 |
| "9790794231" | 4 |
"
+ ],
+ "text/plain": [
+ "shape: (5, 2)\n",
+ "┌─────────────┬─────────────┐\n",
+ "│ old_item_id ┆ new_item_id │\n",
+ "│ --- ┆ --- │\n",
+ "│ str ┆ i64 │\n",
+ "╞═════════════╪═════════════╡\n",
+ "│ 7806397051 ┆ 0 │\n",
+ "│ 9759091062 ┆ 1 │\n",
+ "│ 9788072216 ┆ 2 │\n",
+ "│ 9790790961 ┆ 3 │\n",
+ "│ 9790794231 ┆ 4 │\n",
+ "└─────────────┴─────────────┘"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "item_ids_mapping_df = pl.from_dict({\n",
+ " 'old_item_id': list(item_ids_mapping.keys()),\n",
+ " 'new_item_id': list(item_ids_mapping.values())\n",
+ "})\n",
+ "item_ids_mapping_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "99b54db807b9495c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(item_ids_mapping_output_path, 'w') as f:\n",
+ " json.dump({str(k): v for k, v in item_ids_mapping.items()}, f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "6017e65c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (5, 3)| user_id | item_id | timestamp |
|---|
| i64 | i64 | i64 |
| 0 | 0 | 1391040000 |
| 1 | 0 | 1397779200 |
| 2 | 0 | 1378425600 |
| 3 | 0 | 1386460800 |
| 4 | 0 | 1382140800 |
"
+ ],
+ "text/plain": [
+ "shape: (5, 3)\n",
+ "┌─────────┬─────────┬────────────┐\n",
+ "│ user_id ┆ item_id ┆ timestamp │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ i64 ┆ i64 ┆ i64 │\n",
+ "╞═════════╪═════════╪════════════╡\n",
+ "│ 0 ┆ 0 ┆ 1391040000 │\n",
+ "│ 1 ┆ 0 ┆ 1397779200 │\n",
+ "│ 2 ┆ 0 ┆ 1378425600 │\n",
+ "│ 3 ┆ 0 ┆ 1386460800 │\n",
+ "│ 4 ┆ 0 ┆ 1382140800 │\n",
+ "└─────────┴─────────┴────────────┘"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "filtered_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "9efd1983",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "filtered_df = filtered_df.sort([\"user_id\", \"timestamp\"])\n",
+ "\n",
+ "grouped_filtered_df = filtered_df.group_by(\"user_id\", maintain_order=True).agg(pl.all())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "fd51c525",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (5, 2)| old_item_id | new_item_id |
|---|
| str | i64 |
| "7806397051" | 0 |
| "9759091062" | 1 |
| "9788072216" | 2 |
| "9790790961" | 3 |
| "9790794231" | 4 |
"
+ ],
+ "text/plain": [
+ "shape: (5, 2)\n",
+ "┌─────────────┬─────────────┐\n",
+ "│ old_item_id ┆ new_item_id │\n",
+ "│ --- ┆ --- │\n",
+ "│ str ┆ i64 │\n",
+ "╞═════════════╪═════════════╡\n",
+ "│ 7806397051 ┆ 0 │\n",
+ "│ 9759091062 ┆ 1 │\n",
+ "│ 9788072216 ┆ 2 │\n",
+ "│ 9790790961 ┆ 3 │\n",
+ "│ 9790794231 ┆ 4 │\n",
+ "└─────────────┴─────────────┘"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "item_ids_mapping_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "8b0821da",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (5, 3)| user_id | item_id | timestamp |
|---|
| i64 | list[i64] | list[i64] |
| 0 | [6845, 7872, … 0] | [1318896000, 1318896000, … 1391040000] |
| 1 | [815, 10405, … 232] | [1392422400, 1396224000, … 1397779200] |
| 2 | [6049, 0, … 6608] | [1378425600, 1378425600, … 1400284800] |
| 3 | [5521, 5160, … 0] | [1379116800, 1380931200, … 1386460800] |
| 4 | [0, 10469, … 11389] | [1382140800, 1383523200, … 1388966400] |
"
+ ],
+ "text/plain": [
+ "shape: (5, 3)\n",
+ "┌─────────┬─────────────────────┬─────────────────────────────────┐\n",
+ "│ user_id ┆ item_id ┆ timestamp │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ i64 ┆ list[i64] ┆ list[i64] │\n",
+ "╞═════════╪═════════════════════╪═════════════════════════════════╡\n",
+ "│ 0 ┆ [6845, 7872, … 0] ┆ [1318896000, 1318896000, … 139… │\n",
+ "│ 1 ┆ [815, 10405, … 232] ┆ [1392422400, 1396224000, … 139… │\n",
+ "│ 2 ┆ [6049, 0, … 6608] ┆ [1378425600, 1378425600, … 140… │\n",
+ "│ 3 ┆ [5521, 5160, … 0] ┆ [1379116800, 1380931200, … 138… │\n",
+ "│ 4 ┆ [0, 10469, … 11389] ┆ [1382140800, 1383523200, … 138… │\n",
+ "└─────────┴─────────────────────┴─────────────────────────────────┘"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "grouped_filtered_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "dc222d59",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Users count: 22363\n",
+ "Items count: 12101\n",
+ "Actions count: 198502\n",
+ "Avg user history len: 8.876358270357287\n"
+ ]
+ }
+ ],
+ "source": [
+ "print('Users count:', filtered_df.select('user_id').unique().shape[0])\n",
+ "print('Items count:', filtered_df.select('item_id').unique().shape[0])\n",
+ "print('Actions count:', filtered_df.shape[0])\n",
+ "print('Avg user history len:', np.mean(list(map(lambda x: x[0], grouped_filtered_df.select(pl.col('item_id').list.len()).rows()))))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "a272855d-84b2-4414-ba9f-62647e1151cf",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "shape: (5, 3)\n",
+ "┌─────┬─────────────────────┬─────────────────────────────────┐\n",
+ "│ uid ┆ item_ids ┆ timestamps │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ i64 ┆ list[i64] ┆ list[i64] │\n",
+ "╞═════╪═════════════════════╪═════════════════════════════════╡\n",
+ "│ 0 ┆ [6845, 7872, … 0] ┆ [1318896000, 1318896000, … 139… │\n",
+ "│ 1 ┆ [815, 10405, … 232] ┆ [1392422400, 1396224000, … 139… │\n",
+ "│ 2 ┆ [6049, 0, … 6608] ┆ [1378425600, 1378425600, … 140… │\n",
+ "│ 3 ┆ [5521, 5160, … 0] ┆ [1379116800, 1380931200, … 138… │\n",
+ "│ 4 ┆ [0, 10469, … 11389] ┆ [1382140800, 1383523200, … 138… │\n",
+ "└─────┴─────────────────────┴─────────────────────────────────┘\n"
+ ]
+ }
+ ],
+ "source": [
+ "inter_new = grouped_filtered_df.select([\n",
+ " pl.col(\"user_id\").alias(\"uid\"),\n",
+ " pl.col(\"item_id\").alias(\"item_ids\"),\n",
+ " pl.col(\"timestamp\").alias(\"timestamps\")\n",
+ "])\n",
+ "\n",
+ "print(inter_new.head())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "de5a853a-8ee2-42dd-a71a-6cc6f90d526c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Файл успешно сохранен: ../data/Beauty_new/inter_new.parquet\n"
+ ]
+ }
+ ],
+ "source": [
+ "output_path_parquet = interactions_output_parquet_path\n",
+ "inter_new.write_parquet(output_path_parquet)\n",
+ "\n",
+ "print(f\"Файл успешно сохранен: {output_path_parquet}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "d07a2e91",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "json_data = {}\n",
+ "for user_id, item_ids, _ in grouped_filtered_df.iter_rows():\n",
+ " json_data[user_id] = item_ids\n",
+ "\n",
+ "with open(interactions_output_json_path, 'w') as f:\n",
+ " json.dump(json_data, f, indent=2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "237523fa",
+ "metadata": {
+ "jp-MarkdownHeadingCollapsed": true
+ },
+ "source": [
+ "## Content embedding creation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "6361c7a5",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "KeyboardInterrupt",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[19], line 5\u001b[0m, in \u001b[0;36mgetDF\u001b[0;34m(path)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(path, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m line \u001b[38;5;129;01min\u001b[39;00m \u001b[43mf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreadlines\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 6\u001b[0m df[i] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28meval\u001b[39m(line)\n",
+ "File \u001b[0;32m/usr/lib/python3.10/codecs.py:319\u001b[0m, in \u001b[0;36mBufferedIncrementalDecoder.decode\u001b[0;34m(self, input, final)\u001b[0m\n\u001b[1;32m 317\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m\n\u001b[0;32m--> 319\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mdecode\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m, final\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m 320\u001b[0m \u001b[38;5;66;03m# decode input (taking the buffer into account)\u001b[39;00m\n\u001b[1;32m 321\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuffer \u001b[38;5;241m+\u001b[39m \u001b[38;5;28minput\u001b[39m\n",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: ",
+ "\nDuring handling of the above exception, another exception occurred:\n",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[19], line 11\u001b[0m\n\u001b[1;32m 7\u001b[0m i \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m pd\u001b[38;5;241m.\u001b[39mDataFrame\u001b[38;5;241m.\u001b[39mfrom_dict(df, orient\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mindex\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 11\u001b[0m df \u001b[38;5;241m=\u001b[39m \u001b[43mgetDF\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmetadata_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m df\u001b[38;5;241m.\u001b[39mhead()\n",
+ "Cell \u001b[0;32mIn[19], line 5\u001b[0m, in \u001b[0;36mgetDF\u001b[0;34m(path)\u001b[0m\n\u001b[1;32m 3\u001b[0m df \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(path, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m line \u001b[38;5;129;01min\u001b[39;00m \u001b[43mf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreadlines\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 6\u001b[0m df[i] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28meval\u001b[39m(line)\n\u001b[1;32m 7\u001b[0m i \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+ ]
+ }
+ ],
+ "source": [
+    "def getDF(path):\n",
+    "    # Parses one Python-literal record per line into a DataFrame indexed by line number.\n",
+    "    import ast  # local import keeps this cell self-contained\n",
+    "    with open(path, 'r') as f:\n",
+    "        # literal_eval only accepts literals, so a crafted metadata line cannot execute code\n",
+    "        # (unlike eval); iterating the handle avoids readlines() loading the whole file first.\n",
+    "        records = {i: ast.literal_eval(line) for i, line in enumerate(f)}\n",
+    "\n",
+    "    return pd.DataFrame.from_dict(records, orient=\"index\")\n",
+    "\n",
+    "df = getDF(metadata_path)\n",
+    "df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "971fa89c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def preprocess(row: pd.Series):\n",
+ " row = row.fillna(\"None\")\n",
+ " return f\"Title: {row['title']}. Categories: {', '.join(row['categories'][0])}. Description: {row['description']}.\"\n",
+ "\n",
+ "\n",
+ "def get_data(metadata_df, item_ids_mapping_df):\n",
+ " filtered_df = metadata_df.join(\n",
+ " item_ids_mapping_df, \n",
+ " left_on=\"asin\", \n",
+ " right_on='old_item_id', \n",
+ " how=\"inner\"\n",
+ " ).select(pl.col('new_item_id'), pl.col('title'), pl.col('description'), pl.col('categories'))\n",
+ "\n",
+ " filtered_df = filtered_df.to_pandas()\n",
+ " filtered_df[\"combined_text\"] = filtered_df.apply(preprocess, axis=1)\n",
+ "\n",
+ " return filtered_df\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3b0dd5d5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "data = get_data(pl.from_pandas(df), item_ids_mapping_df)  # the join needs the polars mapping frame, not the dict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "12e622ff",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Device used for inference. NOTE(review): the GPU index is hard-coded - confirm it matches the machine.\n",
+ "device = torch.device('cuda:6')\n",
+ "\n",
+ "model_name = \"huggyllama/llama-7b\"\n",
+ "tokenizer = LlamaTokenizer.from_pretrained(model_name)\n",
+ "\n",
+ "# The dataset below pads every text to a fixed length, so a pad token is required;\n",
+ "# fall back to the EOS token when the tokenizer defines none.\n",
+ "if tokenizer.pad_token is None:\n",
+ " tokenizer.pad_token = tokenizer.eos_token\n",
+ "\n",
+ "model = LlamaModel.from_pretrained(model_name)\n",
+ "model = model.to(device)\n",
+ "# Put the module into evaluation mode before extracting embeddings.\n",
+ "model = model.eval()\n",
+ "\n",
+ "\n",
+ "class MyDataset:\n",
+ "    \"\"\"Map-style dataset yielding tokenized item texts together with their item ids.\n",
+ "\n",
+ "    Relies on the notebook-level `tokenizer`.\n",
+ "    \"\"\"\n",
+ "\n",
+ "    def __init__(self, data):\n",
+ "        # data: pandas DataFrame with 'new_item_id' and 'combined_text' columns.\n",
+ "        # Read the two columns directly instead of converting the whole frame with\n",
+ "        # to_dict() twice.\n",
+ "        self._data = list(zip(data['new_item_id'].tolist(), data['combined_text'].tolist()))\n",
+ "\n",
+ "    def __len__(self):\n",
+ "        return len(self._data)\n",
+ "\n",
+ "    def __getitem__(self, idx):\n",
+ "        item_id, text = self._data[idx]\n",
+ "        # Every sample is padded/truncated to 1024 tokens, so all samples have equal length.\n",
+ "        inputs = tokenizer(text, return_tensors=\"pt\", max_length=1024, truncation=True, padding=\"max_length\")\n",
+ "        return {\n",
+ "            'item_id': item_id,\n",
+ "            # [0] selects the single encoded sequence from the tokenizer output.\n",
+ "            'input_ids': inputs['input_ids'][0],\n",
+ "            'attention_mask': inputs['attention_mask'][0]\n",
+ "        }\n",
+ " \n",
+ "\n",
+ "dataset = MyDataset(data)\n",
+ "loader = DataLoader(dataset, batch_size=8, drop_last=False, shuffle=False, num_workers=10)\n",
+ "\n",
+ "\n",
+ "# Accumulates one pooled embedding per item, in loader order.\n",
+ "new_df = {\n",
+ "    'item_id': [],\n",
+ "    'embedding': []\n",
+ "}\n",
+ "\n",
+ "for batch in tqdm(loader):\n",
+ "    attention_mask = batch[\"attention_mask\"].to(device)\n",
+ "    with torch.inference_mode():\n",
+ "        outputs = model(\n",
+ "            input_ids=batch[\"input_ids\"].to(device),\n",
+ "            attention_mask=attention_mask\n",
+ "        )\n",
+ "        hidden = outputs.last_hidden_state  # (bs, sl, ed)\n",
+ "\n",
+ "        # Masked mean pooling: padded positions contribute nothing and the sum is divided\n",
+ "        # by the number of real tokens. A plain mean(dim=1) over the padded length would\n",
+ "        # scale every embedding by (real tokens / padded length).\n",
+ "        token_mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (bs, sl, 1)\n",
+ "        token_counts = token_mask.sum(dim=1).clamp(min=1)  # (bs, 1)\n",
+ "        pooled = (hidden * token_mask).sum(dim=1) / token_counts  # (bs, ed)\n",
+ "\n",
+ "    new_df['item_id'] += batch['item_id'].tolist()\n",
+ "    new_df['embedding'] += pooled.tolist()\n",
+ "\n",
+ "\n",
+ "with open(embeddings_output_path, 'wb') as f:\n",
+ "    pickle.dump(new_df, f)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a6fffc4a-85f1-424e-b460-29e526df3317",
+ "metadata": {},
+ "source": [
+ "# Test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "1f922431-e3c1-4587-86d1-04a58eb8ffee",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "✓ Сохранено: ../data/Beauty_new/splits/inter_new_[0_1291403520.0).json\n",
+ "✓ Сохранено: ../data/Beauty_new/splits/inter_new_[1291403520.0_1329626880.0).json\n",
+ "✓ Сохранено: ../data/Beauty_new/splits/inter_new_[1329626880.0_1367850240.0).json\n",
+ "✓ Сохранено: ../data/Beauty_new/splits/inter_new_[1367850240.0_inf).json\n",
+ "Интервал 0: 3485 пользователей, 10350 взаимодействий\n",
+ "Интервал 1: 5751 пользователей, 15837 взаимодействий\n",
+ "Интервал 2: 13543 пользователей, 61954 взаимодействий\n",
+ "Интервал 3: 18811 пользователей, 110361 взаимодействий\n"
+ ]
+ }
+ ],
+ "source": [
+ "import polars as pl\n",
+ "import json\n",
+ "from typing import List, Dict\n",
+ "\n",
+ "def split_session_by_timestamps(\n",
+ "    df: pl.DataFrame,\n",
+ "    time_cutoffs: List[int],\n",
+ "    output_dir: str = None,\n",
+ "    return_dicts: bool = True\n",
+ ") -> List[Dict[int, List[int]]]:\n",
+ "    \"\"\"\n",
+ "    Split the dataset into consecutive time intervals and return JSON-like dicts.\n",
+ "\n",
+ "    Cutoffs c1..ck produce the half-open intervals [0, c1), [c1, c2), ..., [ck, inf).\n",
+ "\n",
+ "    Args:\n",
+ "        df: Polars DataFrame with columns uid, item_ids (list), timestamps (list)\n",
+ "        time_cutoffs: List of time points at which to split, used in the given order\n",
+ "        output_dir: Directory for the JSON files (optional); created when missing\n",
+ "        return_dicts: Not read by the implementation; the dicts are always returned\n",
+ "\n",
+ "    Returns:\n",
+ "        List of dicts in the form {user_id: [item_ids inside the interval]}\n",
+ "\n",
+ "    Example:\n",
+ "        >>> df = pl.read_parquet(\"inter_new.parquet\")\n",
+ "        >>> cutoffs = [1000000, 2000000, 3000000]\n",
+ "        >>> dicts = split_session_by_timestamps(df, cutoffs, output_dir=\"./data\")\n",
+ "        >>> # Produces 4 JSON files: one per interval plus the open-ended last one\n",
+ "    \"\"\"\n",
+ "    # Local import keeps the cell self-contained.\n",
+ "    import os\n",
+ "\n",
+ "    def extract_interval(df_source, start, end=None):\n",
+ "        \"\"\"Collect one row per uid with the item_ids whose timestamps fall into [start, end).\"\"\"\n",
+ "        in_interval = pl.col(\"timestamps\") >= start\n",
+ "        if end is not None:\n",
+ "            in_interval = in_interval & (pl.col(\"timestamps\") < end)\n",
+ "\n",
+ "        return (\n",
+ "            df_source.lazy()\n",
+ "            .explode([\"item_ids\", \"timestamps\"])\n",
+ "            .filter(in_interval)\n",
+ "            .group_by(\"uid\")\n",
+ "            .agg(pl.col(\"item_ids\"))\n",
+ "            .sort(\"uid\")\n",
+ "            .collect()\n",
+ "        )\n",
+ "\n",
+ "    # Consecutive boundaries: 0, the cutoffs, then None for the open-ended last interval\n",
+ "    boundaries = [0, *time_cutoffs, None]\n",
+ "    intervals = list(zip(boundaries[:-1], boundaries[1:]))\n",
+ "\n",
+ "    if output_dir:\n",
+ "        # open() does not create missing directories\n",
+ "        os.makedirs(output_dir, exist_ok=True)\n",
+ "\n",
+ "    result_dicts = []\n",
+ "    for start, end in intervals:\n",
+ "        subset = extract_interval(df, start, end)\n",
+ "\n",
+ "        # Rows are (uid, item_ids) pairs\n",
+ "        json_dict = dict(subset.iter_rows())\n",
+ "        result_dicts.append(json_dict)\n",
+ "\n",
+ "        # Optionally write the interval to disk\n",
+ "        if output_dir:\n",
+ "            end_label = end if end is not None else \"inf\"\n",
+ "            filepath = f\"{output_dir}/inter_new_[{start}_{end_label}).json\"\n",
+ "            with open(filepath, 'w') as f:\n",
+ "                json.dump(json_dict, f, indent=2)\n",
+ "\n",
+ "            print(f\"✓ Сохранено: {filepath}\")\n",
+ "\n",
+ "    return result_dicts\n",
+ "\n",
+ "\n",
+ "# ==========================================\n",
+ "# Использование в ноутбуке\n",
+ "# ==========================================\n",
+ "\n",
+ "# Load the saved Parquet file\n",
+ "df = pl.read_parquet(interactions_output_parquet_path)\n",
+ "\n",
+ "# Cutoffs are placed at 70%, 80% and 90% of the covered time span\n",
+ "# (percentiles or fixed dates could be used instead)\n",
+ "df_timestamps = df.select(pl.col(\"timestamps\").explode())\n",
+ "min_time = df_timestamps.select(pl.col(\"timestamps\").min()).item()\n",
+ "max_time = df_timestamps.select(pl.col(\"timestamps\").max()).item()\n",
+ "\n",
+ "# Three cutoffs give four parts (train/val/test/test_final)\n",
+ "cutoffs = [min_time + (max_time - min_time) * share for share in (0.7, 0.8, 0.9)]\n",
+ "\n",
+ "# Apply the split; output_dir is optional\n",
+ "result_dicts = split_session_by_timestamps(df, cutoffs, output_dir=\"../data/Beauty_new/splits\")\n",
+ "\n",
+ "# Print per-interval statistics\n",
+ "for i, json_dict in enumerate(result_dicts):\n",
+ "    total_interactions = sum(map(len, json_dict.values()))\n",
+ "    print(f\"Интервал {i}: {len(json_dict)} пользователей, {total_interactions} взаимодействий\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "73b5ec51-4d94-4021-9a21-3f4345ecdd26",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "✓ Сохранено: ../data/Beauty_new/splits/inter_new_[0_inf).json\n"
+ ]
+ }
+ ],
+ "source": [
+ "split_session_by_timestamps(\n",
+ " df, \n",
+ " [],\n",
+ " output_dir=\"../data/Beauty_new/splits\"\n",
+ ")\n",
+ "None"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/sigir/Beauty/exps_data.ipynb b/sigir/Beauty/exps_data.ipynb
new file mode 100644
index 0000000..2625231
--- /dev/null
+++ b/sigir/Beauty/exps_data.ipynb
@@ -0,0 +1,921 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "e2462a97-6705-44e1-a232-4dd78a5dfc85",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import polars as pl\n",
+ "import json\n",
+ "from typing import List, Dict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "fd38624d-5796-4aa5-929f-7e82c5544f6c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "interactions_output_parquet_path = '/home/jovyan/IRec/sigir/Beauty_new/inter_new.parquet'\n",
+ "# 1. Загружаем\n",
+ "df = pl.read_parquet(interactions_output_parquet_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "ee127317-66b8-4f22-9109-94bcb8b1f1ae",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def split_session_by_timestamps(\n",
+ "    df: pl.DataFrame,\n",
+ "    time_cutoffs: List[int],\n",
+ "    output_dir: str = None,\n",
+ "    return_dicts: bool = True\n",
+ ") -> List[Dict[int, List[int]]]:\n",
+ "    \"\"\"\n",
+ "    Split the dataset into consecutive time intervals and return JSON-like dicts.\n",
+ "\n",
+ "    Cutoffs c1..ck produce the half-open intervals [0, c1), [c1, c2), ..., [ck, inf).\n",
+ "\n",
+ "    Args:\n",
+ "        df: Polars DataFrame with columns uid, item_ids (list), timestamps (list)\n",
+ "        time_cutoffs: List of time points at which to split, used in the given order\n",
+ "        output_dir: Directory for the JSON files (optional); created when missing\n",
+ "        return_dicts: Not read by the implementation; the dicts are always returned\n",
+ "\n",
+ "    Returns:\n",
+ "        List of dicts in the form {user_id: [item_ids inside the interval]}\n",
+ "\n",
+ "    Example:\n",
+ "        >>> df = pl.read_parquet(\"inter_new.parquet\")\n",
+ "        >>> cutoffs = [1000000, 2000000, 3000000]\n",
+ "        >>> dicts = split_session_by_timestamps(df, cutoffs, output_dir=\"./data\")\n",
+ "        >>> # Produces 4 JSON files: one per interval plus the open-ended last one\n",
+ "    \"\"\"\n",
+ "    # Local import keeps the cell self-contained.\n",
+ "    import os\n",
+ "\n",
+ "    def extract_interval(df_source, start, end=None):\n",
+ "        \"\"\"Collect one row per uid with the item_ids whose timestamps fall into [start, end).\"\"\"\n",
+ "        in_interval = pl.col(\"timestamps\") >= start\n",
+ "        if end is not None:\n",
+ "            in_interval = in_interval & (pl.col(\"timestamps\") < end)\n",
+ "\n",
+ "        return (\n",
+ "            df_source.lazy()\n",
+ "            .explode([\"item_ids\", \"timestamps\"])\n",
+ "            .filter(in_interval)\n",
+ "            .group_by(\"uid\")\n",
+ "            .agg(pl.col(\"item_ids\"))\n",
+ "            .sort(\"uid\")\n",
+ "            .collect()\n",
+ "        )\n",
+ "\n",
+ "    # Consecutive boundaries: 0, the cutoffs, then None for the open-ended last interval\n",
+ "    boundaries = [0, *time_cutoffs, None]\n",
+ "    intervals = list(zip(boundaries[:-1], boundaries[1:]))\n",
+ "\n",
+ "    if output_dir:\n",
+ "        # open() does not create missing directories\n",
+ "        os.makedirs(output_dir, exist_ok=True)\n",
+ "\n",
+ "    result_dicts = []\n",
+ "    for start, end in intervals:\n",
+ "        subset = extract_interval(df, start, end)\n",
+ "\n",
+ "        # Rows are (uid, item_ids) pairs\n",
+ "        json_dict = dict(subset.iter_rows())\n",
+ "        result_dicts.append(json_dict)\n",
+ "\n",
+ "        # Optionally write the interval to disk\n",
+ "        if output_dir:\n",
+ "            end_label = end if end is not None else \"inf\"\n",
+ "            filepath = f\"{output_dir}/inter_new_[{start}_{end_label}).json\"\n",
+ "            with open(filepath, 'w') as f:\n",
+ "                json.dump(json_dict, f, indent=2)\n",
+ "\n",
+ "            print(f\"✓ Сохранено: {filepath}\")\n",
+ "\n",
+ "    return result_dicts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "efc8b582-dd8a-4299-9c49-de906251de8a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Cutoffs: [1402444800, 1403654400, 1404864000]\n",
+ "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/splits/test_splits/inter_new_[0_1402444800).json\n",
+ "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/splits/test_splits/inter_new_[1402444800_1403654400).json\n",
+ "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/splits/test_splits/inter_new_[1403654400_1404864000).json\n",
+ "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/splits/test_splits/inter_new_[1404864000_inf).json\n",
+ "Part 0 [Base]: 22029 users\n",
+ "Part 1 [Week -6]: 1854 users\n",
+ "Part 2 [Week -4]: 1945 users\n",
+ "Part 3 [Week -2]: 1381 users\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Latest timestamp in the whole dataset\n",
+ "global_max_time = df.select(pl.col(\"timestamps\").explode().max()).item()\n",
+ "\n",
+ "# 3. Window size: 14 days (two weeks)\n",
+ "days_val = 14\n",
+ "window_sec = days_val * 24 * 3600\n",
+ "\n",
+ "# 4. Three cutoffs counted back from the end: T - 2w, T - 4w, T - 6w\n",
+ "cutoff_test_start, cutoff_val_start, cutoff_gap_start = (\n",
+ "    global_max_time - k * window_sec for k in (1, 2, 3)\n",
+ ")\n",
+ "\n",
+ "# Boundaries in chronological order: Part 0 | Part 1, Part 1 | Part 2, Part 2 | Part 3\n",
+ "cutoffs = [int(c) for c in (cutoff_gap_start, cutoff_val_start, cutoff_test_start)]\n",
+ "\n",
+ "print(f\"Cutoffs: {cutoffs}\")\n",
+ "\n",
+ "# 5. Split into 4 files\n",
+ "# Part 0: deep history\n",
+ "# Part 1: pre-validation (needed for s1, but dropped for the 'very short' s0?)\n",
+ "#   Note: here 4.2 is just 'one window shorter', so Part 1 still goes into the Semantics train\n",
+ "#   and only Part 2 is dropped; to make it even shorter, Part 1 can be dropped too.\n",
+ "#   NOTE(review): the merge cell below builds 4.1 from P0+P1 and 4.2 from P0 only - confirm\n",
+ "#   which description is current.\n",
+ "# Part 2: validation (present in 4.1, absent in 4.2 for Semantics)\n",
+ "# Part 3: test\n",
+ "\n",
+ "split_files = split_session_by_timestamps(\n",
+ "    df,\n",
+ "    cutoffs,\n",
+ "    output_dir=\"/home/jovyan/IRec/sigir/Beauty_new/splits/test_splits\",\n",
+ ")\n",
+ "\n",
+ "names = [\"Base\", \"Week -6\", \"Week -4\", \"Week -2\"]\n",
+ "for i, d in enumerate(split_files):\n",
+ "    print(f\"Part {i} [{names[i]}]: {len(d)} users\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "d5ba172e-b430-40a3-a4fa-64366d02a015",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def merge_and_save(parts_to_merge, dirr, output_name):\n",
+ "    \"\"\"Concatenate per-user item lists from several parts and write them to one JSON file.\n",
+ "\n",
+ "    Args:\n",
+ "        parts_to_merge: List of dicts {uid: [items]}; for a uid found in several parts\n",
+ "            the lists are concatenated in the order the parts are given.\n",
+ "        dirr: Output directory; created if it does not exist yet.\n",
+ "        output_name: File name inside dirr.\n",
+ "    \"\"\"\n",
+ "    # Local import keeps the cell self-contained.\n",
+ "    import os\n",
+ "\n",
+ "    merged = {}\n",
+ "    print(f\"Merging {len(parts_to_merge)} files into {output_name}...\")\n",
+ "\n",
+ "    for part in parts_to_merge:\n",
+ "        for uid, items in part.items():\n",
+ "            # setdefault creates a fresh list, so the input parts are never mutated.\n",
+ "            merged.setdefault(uid, []).extend(items)\n",
+ "\n",
+ "    if dirr:\n",
+ "        # open() does not create missing directories, so make sure the target exists.\n",
+ "        os.makedirs(dirr, exist_ok=True)\n",
+ "    out_path = f\"{dirr}/{output_name}\"\n",
+ "    with open(out_path, 'w') as f:\n",
+ "        json.dump(merged, f)\n",
+ "    print(f\"✓ Done: {out_path} (Users: {len(merged)})\")\n",
+ "\n",
+ "\n",
+ "# p0, p1, p2, p3 = split_files[0], split_files[1], split_files[2], split_files[3]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "d116b7e0-9bf9-4104-86a0-69788a70cc14",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Merging 2 files into exp_4_inter_tiger_train.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4_inter_tiger_train.json (Users: 22129)\n",
+ "Merging 2 files into exp_4.1_inter_semantics_train.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4.1_inter_semantics_train.json (Users: 22129)\n",
+ "Merging 1 files into exp_4.2_inter_semantics_train_short.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4.2_inter_semantics_train_short.json (Users: 22029)\n",
+ "Merging 3 files into exp_4.3_inter_semantics_train_leak.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4.3_inter_semantics_train_leak.json (Users: 22265)\n",
+ "Merging 1 files into test_set.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/test_set.json (Users: 1381)\n",
+ "Merging 1 files into valid_skip_set.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/valid_skip_set.json (Users: 1945)\n",
+ "\n",
+ "All done!\n"
+ ]
+ }
+ ],
+ "source": [
+ "EXP_DIR = \"../sigir/Beauty_new/splits/exp_data\"\n",
+ "\n",
+ "# The four time-ordered parts returned by split_session_by_timestamps.\n",
+ "# Unpacked here so this cell does not depend on names left over in the kernel.\n",
+ "p0, p1, p2, p3 = split_files\n",
+ "\n",
+ "# Tiger: P0+P1\n",
+ "merge_and_save([p0, p1], EXP_DIR, \"exp_4_inter_tiger_train.json\")\n",
+ "\n",
+ "# 1. Exp 4.1 (Standard)\n",
+ "# Semantics: P0+P1 (everything except the skipped part and the test)\n",
+ "merge_and_save([p0, p1], EXP_DIR, \"exp_4.1_inter_semantics_train.json\")\n",
+ "\n",
+ "# 2. Exp 4.2 (Short Semantics)\n",
+ "# Semantics: P0 only (one window shorter than 4.1, i.e. without P1)\n",
+ "merge_and_save([p0], EXP_DIR, \"exp_4.2_inter_semantics_train_short.json\")\n",
+ "\n",
+ "# 3. Exp 4.3 (Leak)\n",
+ "# Semantics: P0+P1+P2 (sees the validation part)\n",
+ "merge_and_save([p0, p1, p2], EXP_DIR, \"exp_4.3_inter_semantics_train_leak.json\")\n",
+ "\n",
+ "# 4. Test set (used to test all models)\n",
+ "merge_and_save([p3], EXP_DIR, \"test_set.json\")\n",
+ "\n",
+ "# 5. Valid set (the skipped part; imitates the gap between train and test for later fine-tuning)\n",
+ "merge_and_save([p2], EXP_DIR, \"valid_skip_set.json\")\n",
+ "\n",
+ "print(\"\\nAll done!\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "9ae1d1e5-567d-471a-8f83-4039ecacc8d2",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Merging 4 files into all_set.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/all_set.json (Users: 22363)\n"
+ ]
+ }
+ ],
+ "source": [
+ "merge_and_save([p0, p1, p2, p3], EXP_DIR, \"all_set.json\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "328de16c-f61d-45be-8a72-5f0eaef612e8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Проверка Train сетов (должны быть префиксами):\n",
+ "✅ [ПРЕФИКСЫ] Все 22129 массивов ОК. Полных совпадений: 19410\n",
+ "✅ [ПРЕФИКСЫ] Все 22029 массивов ОК. Полных совпадений: 18191\n",
+ "✅ [ПРЕФИКСЫ] Все 22265 массивов ОК. Полных совпадений: 20982\n",
+ "✅ [ПРЕФИКСЫ] Все 22129 массивов ОК. Полных совпадений: 19410\n",
+ "\n",
+ "Проверка Test сета (должен быть суффиксом):\n",
+ "✅ [СУФФИКСЫ] Все 1381 массивов ОК. Полных совпадений: 98\n",
+ "\n",
+ "(Контроль) Проверка Test сета как префикса (должна упасть):\n",
+ "❌ [ПРЕФИКСЫ] Найдено 1283 ошибок.\n"
+ ]
+ }
+ ],
+ "source": [
+ "def _read_json(path):\n",
+ "    \"\"\"Load one JSON file and return the parsed object.\"\"\"\n",
+ "    with open(path, 'r') as f:\n",
+ "        return json.load(f)\n",
+ "\n",
+ "\n",
+ "# Full interaction histories and the exported experiment splits\n",
+ "old_inter_new = _read_json(\"../data/Beauty/inter_new.json\")\n",
+ "first_sem = _read_json(\"../sigir/Beauty_new/splits/exp_data/exp_4.1_inter_semantics_train.json\")\n",
+ "second_sem = _read_json(\"../sigir/Beauty_new/splits/exp_data/exp_4.2_inter_semantics_train_short.json\")\n",
+ "third_sem = _read_json(\"../sigir/Beauty_new/splits/exp_data/exp_4.3_inter_semantics_train_leak.json\")\n",
+ "tiger_sem = _read_json(\"../sigir/Beauty_new/splits/exp_data/exp_4_inter_tiger_train.json\")\n",
+ "test_sem = _read_json(\"../sigir/Beauty_new/splits/exp_data/test_set.json\")\n",
+ "\n",
+ "def check_prefix_match(full_data, subset_data, check_suffix=False):\n",
+ "    \"\"\"\n",
+ "    Check that every user's list in subset_data is a prefix (or suffix) of the same\n",
+ "    user's list in full_data, and print a one-line summary.\n",
+ "\n",
+ "    Args:\n",
+ "        full_data: dict {user: [items]} with the complete histories.\n",
+ "        subset_data: dict {user: [items]} to validate; only its users are checked.\n",
+ "        check_suffix: False compares against the beginning of the full list,\n",
+ "            True against its end (used for the test split).\n",
+ "\n",
+ "    A user missing from full_data, or a subset list longer than the full list, counts\n",
+ "    as a mismatch. \"Full matches\" are users whose subset equals the whole full list.\n",
+ "    \"\"\"\n",
+ "    mismatch_count = 0\n",
+ "    full_match_count = 0\n",
+ "\n",
+ "    # Iterate over the subset keys: full_data may contain more users\n",
+ "    for user, sub_items in subset_data.items():\n",
+ "        if user not in full_data:\n",
+ "            print(f\"⚠ Юзер {user} не найден в исходном файле!\")\n",
+ "            mismatch_count += 1\n",
+ "            continue\n",
+ "\n",
+ "        full_items = full_data[user]\n",
+ "        n_sub, n_full = len(sub_items), len(full_items)\n",
+ "        if n_sub > n_full:\n",
+ "            mismatch_count += 1\n",
+ "            continue\n",
+ "\n",
+ "        # Slice with explicit bounds: full_items[-0:] returns the whole list, so an\n",
+ "        # empty subset used to be reported as a suffix mismatch.\n",
+ "        window = full_items[n_full - n_sub:] if check_suffix else full_items[:n_sub]\n",
+ "\n",
+ "        if window == sub_items:\n",
+ "            if n_full == n_sub:\n",
+ "                full_match_count += 1\n",
+ "        else:\n",
+ "            mismatch_count += 1\n",
+ "\n",
+ "    mode = \"СУФФИКСЫ\" if check_suffix else \"ПРЕФИКСЫ\"\n",
+ "\n",
+ "    if mismatch_count == 0:\n",
+ "        print(f\"✅ [{mode}] Все {len(subset_data)} массивов ОК. Полных совпадений: {full_match_count}\")\n",
+ "    else:\n",
+ "        print(f\"❌ [{mode}] Найдено {mismatch_count} ошибок.\")\n",
+ "\n",
+ "# --- Run the checks ---\n",
+ "print(\"Проверка Train сетов (должны быть префиксами):\")\n",
+ "for train_split in (first_sem, second_sem, third_sem, tiger_sem):\n",
+ "    check_prefix_match(old_inter_new, train_split)\n",
+ "\n",
+ "print(\"\\nПроверка Test сета (должен быть суффиксом):\")\n",
+ "check_prefix_match(old_inter_new, test_sem, check_suffix=True)\n",
+ "\n",
+ "# Control run: the test split is not expected to pass the prefix check\n",
+ "print(\"\\n(Контроль) Проверка Test сета как префикса (должна упасть):\")\n",
+ "check_prefix_match(old_inter_new, test_sem)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "0715adfd",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "name 'суа' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mсуа\u001b[49m\n",
+ "\u001b[0;31mNameError\u001b[0m: name 'суа' is not defined"
+ ]
+ }
+ ],
+ "source": [
+ "суа"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "f2df507d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--- Статистика по временным интервалам (Fixed Time Window) ---\n",
+ "Part 0 [Base]: 186516 events (Start -> 2014-06-11)\n",
+ "Part 1 [Gap (Week -6)]: 4073 events (2014-06-11 -> 2014-06-25)\n",
+ "Part 2 [Pre-Valid (Week -4)]: 4730 events (2014-06-25 -> 2014-07-09)\n",
+ "Part 3 [Test (Week -2)]: 3183 events (2014-07-09 -> Inf)\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import polars as pl\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from datetime import datetime, timezone\n",
+ "\n",
+ "\n",
+ "# 2. Statistics for the current time-based split\n",
+ "global_max_time = df.select(pl.col(\"timestamps\").explode().max()).item()\n",
+ "days_val = 14\n",
+ "window_sec = days_val * 24 * 3600\n",
+ "\n",
+ "cutoffs = [\n",
+ "    int(global_max_time - 3 * window_sec),\n",
+ "    int(global_max_time - 2 * window_sec),\n",
+ "    int(global_max_time - 1 * window_sec)\n",
+ "]\n",
+ "\n",
+ "print(\"--- Статистика по временным интервалам (Fixed Time Window) ---\")\n",
+ "intervals = [0] + cutoffs + [None]\n",
+ "labels = [\"Base\", \"Gap (Week -6)\", \"Pre-Valid (Week -4)\", \"Test (Week -2)\"]\n",
+ "\n",
+ "# Count the events in every interval\n",
+ "counts = []\n",
+ "for i in range(len(intervals)-1):\n",
+ "    start, end = intervals[i], intervals[i+1]\n",
+ "\n",
+ "    q = df.lazy().explode([\"timestamps\"])\n",
+ "    if end is not None:\n",
+ "        q = q.filter((pl.col(\"timestamps\") >= start) & (pl.col(\"timestamps\") < end))\n",
+ "    else:\n",
+ "        q = q.filter(pl.col(\"timestamps\") >= start)\n",
+ "\n",
+ "    count = q.select(pl.len()).collect().item()\n",
+ "    counts.append(count)\n",
+ "\n",
+ "    # Format in UTC so the printed dates do not depend on the machine's local time zone\n",
+ "    end_str = datetime.fromtimestamp(end, tz=timezone.utc).strftime('%Y-%m-%d') if end else \"Inf\"\n",
+ "    start_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime('%Y-%m-%d') if start > 0 else \"Start\"\n",
+ "\n",
+ "    print(f\"Part {i} [{labels[i]}]: {count} events ({start_str} -> {end_str})\")\n",
+ "\n",
+ "# 3. Histogram of the events over time\n",
+ "all_timestamps = df.select(pl.col(\"timestamps\").explode()).to_series().to_numpy()\n",
+ "\n",
+ "plt.figure(figsize=(12, 6))\n",
+ "plt.hist(all_timestamps, bins=100, color='skyblue', alpha=0.7, label='Events')\n",
+ "\n",
+ "# Draw the cutoff lines\n",
+ "colors = ['red', 'orange', 'green']\n",
+ "for cutoff, color, label in zip(cutoffs, colors, labels[1:]):\n",
+ "    plt.axvline(x=cutoff, color=color, linestyle='--', linewidth=2, label=f'Cutoff: {label}')\n",
+ "\n",
+ "plt.title(\"Распределение взаимодействий во времени\")\n",
+ "plt.xlabel(\"Timestamp\")\n",
+ "plt.ylabel(\"Количество событий\")\n",
+ "plt.legend()\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.show()\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "901e7400",
+ "metadata": {},
+ "source": [
+ "# QUANTILE CUTOFF"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "8c691891",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_quantile_cutoffs(df, num_parts=4, base_ratio=None):\n",
+ "    \"\"\"\n",
+ "    Compute timestamp cutoffs that split the events into parts by event count.\n",
+ "\n",
+ "    Args:\n",
+ "        df: Polars DataFrame with a list column \"timestamps\".\n",
+ "        num_parts: Number of equally sized parts. With base_ratio set, this is the\n",
+ "            number of parts the tail (everything after Base) is divided into.\n",
+ "        base_ratio: Share of all events given to the first part (Base).\n",
+ "            If None (or 0), all events are divided evenly into num_parts parts.\n",
+ "\n",
+ "    Returns:\n",
+ "        List of timestamps: num_parts cutoffs when base_ratio is set (Base plus\n",
+ "        num_parts tail parts), otherwise num_parts - 1 cutoffs. Parts are only\n",
+ "        roughly equal when several events share the cutoff timestamp.\n",
+ "\n",
+ "    Raises:\n",
+ "        ValueError: if num_parts < 1, base_ratio is outside [0, 1), or cutoffs are\n",
+ "            requested for a frame without events.\n",
+ "    \"\"\"\n",
+ "    if num_parts < 1:\n",
+ "        raise ValueError(f\"num_parts must be >= 1, got {num_parts}\")\n",
+ "    if base_ratio is not None and not 0 <= base_ratio < 1:\n",
+ "        raise ValueError(f\"base_ratio must be in [0, 1), got {base_ratio}\")\n",
+ "\n",
+ "    # Flatten all timestamps into one sorted series (materialised in memory)\n",
+ "    all_ts = df.select(pl.col(\"timestamps\").explode()).to_series().sort()\n",
+ "    total_events = len(all_ts)\n",
+ "\n",
+ "    print(f\"Всего событий: {total_events}\")\n",
+ "\n",
+ "    if base_ratio:\n",
+ "        # Base takes the first base_ratio share; the rest is cut into num_parts equal parts.\n",
+ "        # The first index separates Base, the following num_parts - 1 separate the tail parts.\n",
+ "        base_idx = int(total_events * base_ratio)\n",
+ "        part_size = (total_events - base_idx) // num_parts\n",
+ "        cut_indices = [base_idx + k * part_size for k in range(num_parts)]\n",
+ "    else:\n",
+ "        # Split everything into num_parts equal parts\n",
+ "        step = total_events // num_parts\n",
+ "        cut_indices = [k * step for k in range(1, num_parts)]\n",
+ "\n",
+ "    if cut_indices and total_events == 0:\n",
+ "        raise ValueError(\"cannot compute cutoffs: df contains no events\")\n",
+ "\n",
+ "    # Clamp as a guard against float rounding pushing an index past the last event\n",
+ "    return [all_ts[min(idx, total_events - 1)] for idx in cut_indices]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "13c1466f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Всего событий: 198502\n",
+ "\n",
+ "--- Новые Cutoffs (по количеству событий) ---\n",
+ "Cutoffs: [1394150400, 1397001600, 1399939200, 1403049600]\n",
+ "[0, 1394150400, 1397001600, 1399939200, 1403049600, None]\n",
+ "Проверка количества событий в новых частях:\n",
+ "Part 0: 158689 events\n",
+ "Part 1: 9965 events\n",
+ "Part 2: 9701 events\n",
+ "Part 3: 10137 events\n",
+ "Part 4: 10010 events\n"
+ ]
+ }
+ ],
+ "source": [
+ "equal_event_cutoffs = get_quantile_cutoffs(df, num_parts=4, base_ratio=0.8)\n",
+ "\n",
+ "print(\"\\n--- Новые Cutoffs (по количеству событий) ---\")\n",
+ "print(f\"Cutoffs: {equal_event_cutoffs}\")\n",
+ "\n",
+ "# Check the resulting distribution\n",
+ "intervals_eq = [0] + equal_event_cutoffs + [None]\n",
+ "print(intervals_eq)\n",
+ "print(\"Проверка количества событий в новых частях:\")\n",
+ "for i in range(len(intervals_eq)-1):\n",
+ "    start, end = intervals_eq[i], intervals_eq[i+1]\n",
+ "    # Lower bound always applies; the upper bound only when the interval is closed\n",
+ "    in_part = pl.col(\"timestamps\") >= start\n",
+ "    if end:\n",
+ "        in_part = in_part & (pl.col(\"timestamps\") < end)\n",
+ "    count = df.lazy().explode([\"timestamps\"]).filter(in_part).select(pl.len()).collect().item()\n",
+ "    print(f\"Part {i}: {count} events\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "4e7f7b46",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/raw/inter_new_[0_1394150400).json\n",
+ "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/raw/inter_new_[1394150400_1399939200).json\n",
+ "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/raw/inter_new_[1399939200_1403049600).json\n",
+ "✓ Сохранено: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/raw/inter_new_[1403049600_inf).json\n",
+ "0 Base 20825 158689 \n",
+ "1 Gap 6816 19666 \n",
+ "2 Valid 3817 10137 \n",
+ "3 Test 3626 10010 \n"
+ ]
+ }
+ ],
+ "source": [
+ "# Gap spans two of the equally sized tail parts, so the second quantile cutoff is skipped.\n",
+ "# The values come from equal_event_cutoffs instead of hard-coded timestamps, which keeps\n",
+ "# the split in sync with the data.\n",
+ "quantile_split_cutoffs = [equal_event_cutoffs[0], equal_event_cutoffs[2], equal_event_cutoffs[3]]\n",
+ "\n",
+ "new_split_files = split_session_by_timestamps(\n",
+ "    df,\n",
+ "    quantile_split_cutoffs,\n",
+ "    output_dir=\"/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/raw\"\n",
+ ")\n",
+ "\n",
+ "# One line per part: index, name, number of users, number of events\n",
+ "names = [\"Base\", \"Gap\", \"Valid\", \"Test\"]\n",
+ "for i, d in enumerate(new_split_files):\n",
+ "    num_users = len(d)\n",
+ "\n",
+ "    num_events = sum(len(items) for items in d.values())\n",
+ "\n",
+ "    print(f\"{i:<10} {names[i]:<10} {num_users:<10} {num_events:<10}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "82fd2bca",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Merging 2 files into exp_4_0.9_inter_tiger_train.json...\n",
+ "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4_0.9_inter_tiger_train.json (Users: 21760)\n",
+ "Merging 2 files into exp_4-1_0.9_inter_semantics_train.json...\n",
+ "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-1_0.9_inter_semantics_train.json (Users: 21760)\n",
+ "Merging 1 files into exp_4-2_0.8_inter_semantics_train.json...\n",
+ "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-2_0.8_inter_semantics_train.json (Users: 20825)\n",
+ "Merging 3 files into exp_4-3_0.95_inter_semantics_train.json...\n",
+ "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/exp_4-3_0.95_inter_semantics_train.json (Users: 22079)\n",
+ "Merging 1 files into test_set.json...\n",
+ "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/test_set.json (Users: 3626)\n",
+ "Merging 1 files into valid_set.json...\n",
+ "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/valid_set.json (Users: 3817)\n",
+ "Merging 4 files into all_set.json...\n",
+ "✓ Done: /home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps/all_set.json (Users: 22363)\n",
+ "All done!\n"
+ ]
+ }
+ ],
+ "source": [
+ "EXP_DIR = \"/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps\"\n",
+ "\n",
+ "base_p, gap_p, valid_p, test_p = new_split_files[:4]\n",
+ "\n",
+ "# (output file, parts to merge), written in this order\n",
+ "exports = [\n",
+ "    # Tiger: base + gap\n",
+ "    (\"exp_4_0.9_inter_tiger_train.json\", [base_p, gap_p]),\n",
+ "    # 1. Exp 4.1 (Standard) - Semantics: base + gap (everything except validation and test)\n",
+ "    (\"exp_4-1_0.9_inter_semantics_train.json\", [base_p, gap_p]),\n",
+ "    # 2. Exp 4.2 (Short Semantics) - Semantics: base only (shorter by the gap part)\n",
+ "    (\"exp_4-2_0.8_inter_semantics_train.json\", [base_p]),\n",
+ "    # 3. Exp 4.3 (Leak) - Semantics: base + gap + valid (sees the validation part)\n",
+ "    (\"exp_4-3_0.95_inter_semantics_train.json\", [base_p, gap_p, valid_p]),\n",
+ "    # 4. Test set (used to test all models)\n",
+ "    (\"test_set.json\", [test_p]),\n",
+ "    # 5. Valid set (validation part)\n",
+ "    (\"valid_set.json\", [valid_p]),\n",
+ "    # 6. All set (all data)\n",
+ "    (\"all_set.json\", [base_p, gap_p, valid_p, test_p]),\n",
+ "]\n",
+ "\n",
+ "for file_name, parts in exports:\n",
+ "    merge_and_save(parts, EXP_DIR, file_name)\n",
+ "\n",
+ "print(\"All done!\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "d34b1c55",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Проверка Train сетов (должны быть префиксами):\n",
+ "доля событий всего 0.90:\n",
+ "✅ [ПРЕФИКСЫ] Все 21760 массивов ОК. Полных совпадений: 16175\n",
+ "доля событий всего 0.80:\n",
+ "✅ [ПРЕФИКСЫ] Все 20825 массивов ОК. Полных совпадений: 12129\n",
+ "доля событий всего 0.95:\n",
+ "✅ [ПРЕФИКСЫ] Все 22079 массивов ОК. Полных совпадений: 18737\n",
+ "доля событий всего 0.90:\n",
+ "✅ [ПРЕФИКСЫ] Все 21760 массивов ОК. Полных совпадений: 16175\n",
+ "\n",
+ "Проверка Test сета (должен быть суффиксом):\n",
+ "доля событий всего 0.05:\n",
+ "✅ [СУФФИКСЫ] Все 3626 массивов ОК. Полных совпадений: 284\n",
+ "\n",
+ "(Контроль) Проверка Test сета как префикса (должна упасть):\n",
+ "доля событий всего 0.05:\n",
+ "❌ [ПРЕФИКСЫ] Найдено 3342 ошибок.\n",
+ "доля событий всего 1.00:\n",
+ "✅ [ПРЕФИКСЫ] Все 22363 массивов ОК. Полных совпадений: 22363\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Directory that holds the merged experiment splits.\n",
+ "MERGED_DIR = \"/home/jovyan/IRec/sigir/Beauty_new/updated_quantile_splits/merged_for_exps\"\n",
+ "\n",
+ "\n",
+ "def _load_json(path):\n",
+ "    \"\"\"Read one JSON file and return the parsed object.\"\"\"\n",
+ "    with open(path, 'r') as f:\n",
+ "        return json.load(f)\n",
+ "\n",
+ "\n",
+ "# Reference interaction histories that every split is compared against.\n",
+ "old_inter_new = _load_json(\"/home/jovyan/IRec/data/Beauty/inter_new.json\")\n",
+ "\n",
+ "# Merged splits written to MERGED_DIR (the test path previously started with a doubled slash).\n",
+ "first_sem = _load_json(f\"{MERGED_DIR}/exp_4-1_0.9_inter_semantics_train.json\")\n",
+ "second_sem = _load_json(f\"{MERGED_DIR}/exp_4-2_0.8_inter_semantics_train.json\")\n",
+ "third_sem = _load_json(f\"{MERGED_DIR}/exp_4-3_0.95_inter_semantics_train.json\")\n",
+ "tiger_sem = _load_json(f\"{MERGED_DIR}/exp_4_0.9_inter_tiger_train.json\")\n",
+ "test_sem = _load_json(f\"{MERGED_DIR}/test_set.json\")\n",
+ "all_test_data = _load_json(f\"{MERGED_DIR}/all_set.json\")\n",
+ "\n",
+ "def check_prefix_match(full_data, subset_data, check_suffix=False):\n",
+ "    \"\"\"\n",
+ "    Check that every sequence in ``subset_data`` is a prefix (or suffix) of the\n",
+ "    sequence stored under the same key in ``full_data`` and print a summary.\n",
+ "\n",
+ "    Args:\n",
+ "        full_data: mapping user -> sequence of items (the reference).\n",
+ "        subset_data: mapping user -> sequence of items to verify.\n",
+ "        check_suffix: when True, compare against the end of the reference\n",
+ "            sequence instead of its beginning (used for the test split).\n",
+ "\n",
+ "    Prints the share of reference events present in ``subset_data``, then either\n",
+ "    the number of verified sequences together with how many equal the reference\n",
+ "    in full, or the number of mismatches. A user missing from ``full_data`` is\n",
+ "    reported and counted as a mismatch. Returns None.\n",
+ "    \"\"\"\n",
+ "    mismatch_count = 0\n",
+ "    full_match_count = 0\n",
+ "\n",
+ "    num_events_full_data = sum(len(items) for items in full_data.values())\n",
+ "    num_events_subset_data = sum(len(items) for items in subset_data.values())\n",
+ "    # An empty reference has no events to take a share of: report 0 instead of dividing by zero.\n",
+ "    share = num_events_subset_data / num_events_full_data if num_events_full_data else 0.0\n",
+ "    print(f\"доля событий всего {share:.2f}:\")\n",
+ "\n",
+ "    # Iterate over the subset: full_data may contain users the subset does not.\n",
+ "    for user, sub_items in subset_data.items():\n",
+ "        if user not in full_data:\n",
+ "            print(f\"⚠ Юзер {user} не найден в исходном файле!\")\n",
+ "            mismatch_count += 1\n",
+ "            continue\n",
+ "\n",
+ "        full_items = full_data[user]\n",
+ "        n_sub, n_full = len(sub_items), len(full_items)\n",
+ "\n",
+ "        # A sequence longer than its reference can be neither a prefix nor a suffix.\n",
+ "        if n_sub > n_full:\n",
+ "            mismatch_count += 1\n",
+ "            continue\n",
+ "\n",
+ "        # Use explicit bounds: full_items[-0:] is the whole list, which made an\n",
+ "        # empty subset fail the suffix check.\n",
+ "        if check_suffix:\n",
+ "            window = full_items[n_full - n_sub:]\n",
+ "        else:\n",
+ "            window = full_items[:n_sub]\n",
+ "\n",
+ "        if window != sub_items:\n",
+ "            mismatch_count += 1\n",
+ "        elif n_sub == n_full:\n",
+ "            full_match_count += 1\n",
+ "\n",
+ "    mode = \"СУФФИКСЫ\" if check_suffix else \"ПРЕФИКСЫ\"\n",
+ "\n",
+ "    if mismatch_count == 0:\n",
+ "        print(f\"✅ [{mode}] Все {len(subset_data)} массивов ОК. Полных совпадений: {full_match_count}\")\n",
+ "    else:\n",
+ "        print(f\"❌ [{mode}] Найдено {mismatch_count} ошибок.\")\n",
+ "\n",
+ "# --- Run the checks ---\n",
+ "print(\"Проверка Train сетов (должны быть префиксами):\")\n",
+ "for train_split in (first_sem, second_sem, third_sem, tiger_sem):\n",
+ "    check_prefix_match(old_inter_new, train_split)\n",
+ "\n",
+ "print(\"\\nПроверка Test сета (должен быть суффиксом):\")\n",
+ "check_prefix_match(old_inter_new, test_sem, check_suffix=True)\n",
+ "\n",
+ "# Negative control: the same test split checked as a prefix is expected to report errors.\n",
+ "print(\"\\n(Контроль) Проверка Test сета как префикса (должна упасть):\")\n",
+ "check_prefix_match(old_inter_new, test_sem)\n",
+ "\n",
+ "# The set merged from all parts, checked as a prefix of the reference.\n",
+ "check_prefix_match(old_inter_new, all_test_data)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "501fae46",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Part 0 [Base]: 19666 events (2014-03-07 -> 2014-05-13)\n",
+ "Part 1 [Gap]: 10137 events (2014-05-13 -> 2014-06-18)\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "import polars as pl\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from datetime import datetime\n",
+ "labels = [\"Base\", \"Gap\", \"Valid\", \"Test\"]\n",
+ "\n",
+ "# Count the events inside every [start, end) interval.\n",
+ "# intervals[i] is where labels[i] starts; an end of None means no upper bound.\n",
+ "counts = []\n",
+ "intervals = [1394150400, 1399939200, 1403049600]\n",
+ "# Explode once: the flattened timestamps are the same for every interval.\n",
+ "exploded = df.lazy().explode([\"timestamps\"])\n",
+ "for i in range(len(intervals) - 1):\n",
+ "    start, end = intervals[i], intervals[i + 1]\n",
+ "\n",
+ "    in_window = pl.col(\"timestamps\") >= start\n",
+ "    if end is not None:\n",
+ "        in_window = in_window & (pl.col(\"timestamps\") < end)\n",
+ "\n",
+ "    count = exploded.filter(in_window).select(pl.len()).collect().item()\n",
+ "    counts.append(count)\n",
+ "\n",
+ "    # NOTE(review): fromtimestamp() formats in the local timezone of the kernel.\n",
+ "    end_str = datetime.fromtimestamp(end).strftime('%Y-%m-%d') if end else \"Inf\"\n",
+ "    start_str = datetime.fromtimestamp(start).strftime('%Y-%m-%d') if start > 0 else \"Start\"\n",
+ "\n",
+ "    print(f\"Part {i} [{labels[i]}]: {count} events ({start_str} -> {end_str})\")\n",
+ "\n",
+ "# 3. Histogram of the events over time\n",
+ "all_timestamps = df.select(pl.col(\"timestamps\").explode()).to_series().to_numpy()\n",
+ "\n",
+ "plt.figure(figsize=(12, 6))\n",
+ "plt.hist(all_timestamps, bins=100, color='skyblue', alpha=0.7, label='Events')\n",
+ "\n",
+ "# Draw the boundaries. intervals[k] is the start of labels[k] (see the loop above),\n",
+ "# so they are paired index-for-index; labels[1:] named every line after the next segment.\n",
+ "colors = ['red', 'orange', 'green']\n",
+ "for cutoff, color, label in zip(intervals, colors, labels):\n",
+ "    plt.axvline(x=cutoff, color=color, linestyle='--', linewidth=2, label=f'Cutoff: {label}')\n",
+ "\n",
+ "plt.title(\"Распределение взаимодействий во времени\")\n",
+ "plt.xlabel(\"Timestamp\")\n",
+ "plt.ylabel(\"Количество событий\")\n",
+ "plt.legend()\n",
+ "plt.grid(True, alpha=0.3)\n",
+ "plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/sigir/yambda_processing/YambdaDatasetProcessing.ipynb b/sigir/yambda_processing/YambdaDatasetProcessing.ipynb
new file mode 100644
index 0000000..c36af65
--- /dev/null
+++ b/sigir/yambda_processing/YambdaDatasetProcessing.ipynb
@@ -0,0 +1,640 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "SbkKok0dfjjS"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.12/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "from collections import defaultdict, Counter\n",
+ "from typing import Any, Dict, List, Optional, Tuple\n",
+ "\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "import numpy as np\n",
+ "\n",
+ "import polars as pl\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.utils.data import Dataset, DataLoader"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "gwwdsnwBfjjT"
+ },
+ "source": [
+ "## 🛠️ Подготовка данных"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "viKiSaEKfjjT",
+ "outputId": "6229cbba-dc3b-4d15-a8e4-ac08e4e187d6"
+ },
+ "outputs": [],
+ "source": [
+ "# Coordinates of the Yambda subset. In this cell they are referenced only by the\n",
+ "# commented-out Hugging Face download below.\n",
+ "# NOTE(review): `format` shadows the Python builtin of the same name.\n",
+ "format = 'sequential'\n",
+ "size = '50m'\n",
+ "events = 'listens'\n",
+ "# Alternative source: download through `datasets.load_dataset` instead of reading the local file.\n",
+ "# listens_data = load_dataset('yandex/yambda', data_dir=f'{format}/{size}', data_files=f'{events}.parquet')\n",
+ "# yambda_df = pl.from_arrow(listens_data['train'].data.table)\n",
+ "# Read the listens table from a local parquet file.\n",
+ "yambda_df = pl.read_parquet(\"/home/jovyan/yambda_sequential_50m/sequential/50m/listens.parquet\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "VNanksDRfjjT",
+ "outputId": "e118e2b4-0076-475d-9104-5e1565dab7d9"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "✅ test_yambda_data_loading: OK\n"
+ ]
+ }
+ ],
+ "source": [
+ "def test_yambda_data_loading():\n",
+ "    \"\"\"Sanity-check the loaded ``yambda_df``: type, shape, columns, list dtypes, no empty histories.\"\"\"\n",
+ "    assert isinstance(yambda_df, pl.DataFrame), 'yambda_df должен быть Polars DataFrame'\n",
+ "\n",
+ "    n_rows, n_cols = 9238, 6\n",
+ "    assert yambda_df.shape == (n_rows, n_cols), f'Неправильный размер: {yambda_df.shape}'\n",
+ "\n",
+ "    required = {'uid', 'timestamp', 'item_id', 'is_organic', 'played_ratio_pct', 'track_length_seconds'}\n",
+ "    found = set(yambda_df.columns)\n",
+ "    assert found == required, f'Неправильные колонки: {yambda_df.columns}'\n",
+ "\n",
+ "    # Both sequence columns must be lists of unsigned 32-bit integers.\n",
+ "    list_u32 = pl.List(pl.UInt32)\n",
+ "    for name in ('item_id', 'timestamp'):\n",
+ "        assert yambda_df[name].dtype == list_u32, f'{name} должен быть List[UInt32]'\n",
+ "\n",
+ "    shortest = yambda_df['item_id'].list.len().min()\n",
+ "    assert shortest > 0, 'Есть пустые истории'\n",
+ "\n",
+ "    print('✅ test_yambda_data_loading: OK')\n",
+ "\n",
+ "test_yambda_data_loading()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 527
+ },
+ "id": "q33EG4wlc8ev",
+ "outputId": "01d03740-713e-46e8-d8c5-81ebf6b71546"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(shape: (5, 6)\n",
+ " ┌─────┬─────────────────────┬────────────┬─────────────┬─────────────────────┬─────────────────────┐\n",
+ " │ uid ┆ timestamp ┆ item_id ┆ is_organic ┆ played_ratio_pct ┆ track_length_second │\n",
+ " │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ s │\n",
+ " │ u32 ┆ list[u32] ┆ list[u32] ┆ list[u8] ┆ list[u16] ┆ --- │\n",
+ " │ ┆ ┆ ┆ ┆ ┆ list[u32] │\n",
+ " ╞═════╪═════════════════════╪════════════╪═════════════╪═════════════════════╪═════════════════════╡\n",
+ " │ 100 ┆ [39420, 39420, … ┆ [8326270, ┆ [0, 0, … 0] ┆ [100, 100, … 100] ┆ [170, 105, … 165] │\n",
+ " │ ┆ 25966140] ┆ 1441281, … ┆ ┆ ┆ │\n",
+ " │ ┆ ┆ 4734787] ┆ ┆ ┆ │\n",
+ " │ 200 ┆ [14329075, ┆ [3285270, ┆ [1, 1, … 1] ┆ [9, 28, … 100] ┆ [170, 170, … 145] │\n",
+ " │ ┆ 14329075, … ┆ 5253582, … ┆ ┆ ┆ │\n",
+ " │ ┆ 2545672… ┆ 3778807] ┆ ┆ ┆ │\n",
+ " │ 300 ┆ [54090, 54100, … ┆ [618910, ┆ [1, 1, … 1] ┆ [2, 4, … 15] ┆ [270, 130, … 210] │\n",
+ " │ ┆ 25907225] ┆ 8793425, … ┆ ┆ ┆ │\n",
+ " │ ┆ ┆ 9286415] ┆ ┆ ┆ │\n",
+ " │ 500 ┆ [22695440, ┆ [6417502, ┆ [0, 0, … 1] ┆ [100, 37, … 13] ┆ [225, 210, … 230] │\n",
+ " │ ┆ 22695690, … ┆ 6896222, … ┆ ┆ ┆ │\n",
+ " │ ┆ 2486145… ┆ 4077285] ┆ ┆ ┆ │\n",
+ " │ 600 ┆ [1329190, 1329405, ┆ [8077497, ┆ [0, 0, … 0] ┆ [100, 100, … 100] ┆ [245, 215, … 205] │\n",
+ " │ ┆ … 25997540] ┆ 1865247, … ┆ ┆ ┆ │\n",
+ " │ ┆ ┆ 6481452] ┆ ┆ ┆ │\n",
+ " └─────┴─────────────────────┴────────────┴─────────────┴─────────────────────┴─────────────────────┘,\n",
+ " shape: (0, 6)\n",
+ " ┌─────┬───────────┬───────────┬────────────┬──────────────────┬──────────────────────┐\n",
+ " │ uid ┆ timestamp ┆ item_id ┆ is_organic ┆ played_ratio_pct ┆ track_length_seconds │\n",
+ " │ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
+ " │ u32 ┆ list[u32] ┆ list[u32] ┆ list[u8] ┆ list[u16] ┆ list[u32] │\n",
+ " ╞═════╪═══════════╪═══════════╪════════════╪══════════════════╪══════════════════════╡\n",
+ " └─────┴───────────┴───────────┴────────────┴──────────────────┴──────────────────────┘)"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Show the first rows and every row whose timestamp and item_id lists differ in length.\n",
+ "length_mismatch = pl.col('timestamp').list.len() != pl.col('item_id').list.len()\n",
+ "(yambda_df.head(), yambda_df.filter(length_mismatch))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "id": "-9ou8IARfjjT"
+ },
+ "outputs": [
+ {
+ "ename": "ColumnNotFoundError",
+ "evalue": "'explode' on column: 'is_organic' is invalid\n\nSchema at this point: Schema:\nname: _idx, field: UInt32\nname: uid, field: UInt32\nname: timestamp, field: List(UInt32)\nname: item_id, field: List(UInt32)\n",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mColumnNotFoundError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[7], line 5\u001b[0m\n\u001b[1;32m 1\u001b[0m yambda_df \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 2\u001b[0m \u001b[43myambda_df\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfilter\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpl\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcol\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43muid\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m%\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m200\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m==\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# надо убрать\u001b[39;49;00m\n\u001b[1;32m 4\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwith_row_index\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43m_idx\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m----> 5\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexplode\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\n\u001b[1;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtimestamp\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mitem_id\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mis_organic\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 9\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mplayed_ratio_pct\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 10\u001b[0m \u001b[43m 
\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtrack_length_seconds\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[43m \u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 12\u001b[0m \u001b[38;5;241m.\u001b[39mfilter(\n\u001b[1;32m 13\u001b[0m (pl\u001b[38;5;241m.\u001b[39mcol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mis_organic\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m) \u001b[38;5;241m&\u001b[39m\n\u001b[1;32m 14\u001b[0m (pl\u001b[38;5;241m.\u001b[39mcol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mplayed_ratio_pct\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m50\u001b[39m)\n\u001b[1;32m 15\u001b[0m )\n\u001b[1;32m 16\u001b[0m \u001b[38;5;241m.\u001b[39mgroup_by([\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_idx\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124muid\u001b[39m\u001b[38;5;124m'\u001b[39m], maintain_order\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[1;32m 17\u001b[0m \u001b[38;5;241m.\u001b[39magg([\n\u001b[1;32m 18\u001b[0m pl\u001b[38;5;241m.\u001b[39mcol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtimestamp\u001b[39m\u001b[38;5;124m'\u001b[39m),\n\u001b[1;32m 19\u001b[0m pl\u001b[38;5;241m.\u001b[39mcol(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mitem_id\u001b[39m\u001b[38;5;124m'\u001b[39m),\n\u001b[1;32m 20\u001b[0m ])\n\u001b[1;32m 21\u001b[0m \u001b[38;5;241m.\u001b[39mdrop(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_idx\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 22\u001b[0m )\n",
+ "File \u001b[0;32m/usr/local/lib/python3.12/dist-packages/polars/dataframe/frame.py:8072\u001b[0m, in \u001b[0;36mDataFrame.explode\u001b[0;34m(self, columns, *more_columns)\u001b[0m\n\u001b[1;32m 8015\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mexplode\u001b[39m(\n\u001b[1;32m 8016\u001b[0m \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m 8017\u001b[0m columns: \u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m|\u001b[39m Expr \u001b[38;5;241m|\u001b[39m Sequence[\u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m|\u001b[39m Expr],\n\u001b[1;32m 8018\u001b[0m \u001b[38;5;241m*\u001b[39mmore_columns: \u001b[38;5;28mstr\u001b[39m \u001b[38;5;241m|\u001b[39m Expr,\n\u001b[1;32m 8019\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m DataFrame:\n\u001b[1;32m 8020\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 8021\u001b[0m \u001b[38;5;124;03m Explode the dataframe to long format by exploding the given columns.\u001b[39;00m\n\u001b[1;32m 8022\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 8070\u001b[0m \u001b[38;5;124;03m └─────────┴─────────┘\u001b[39;00m\n\u001b[1;32m 8071\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 8072\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlazy\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexplode\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcolumns\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mmore_columns\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollect\u001b[49m\u001b[43m(\u001b[49m\u001b[43m_eager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m/usr/local/lib/python3.12/dist-packages/polars/lazyframe/frame.py:2053\u001b[0m, in \u001b[0;36mLazyFrame.collect\u001b[0;34m(self, type_coercion, predicate_pushdown, projection_pushdown, simplify_expression, slice_pushdown, comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, collapse_joins, no_optimization, streaming, engine, background, _eager, **_kwargs)\u001b[0m\n\u001b[1;32m 2051\u001b[0m \u001b[38;5;66;03m# Only for testing purposes\u001b[39;00m\n\u001b[1;32m 2052\u001b[0m callback \u001b[38;5;241m=\u001b[39m _kwargs\u001b[38;5;241m.\u001b[39mget(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpost_opt_callback\u001b[39m\u001b[38;5;124m\"\u001b[39m, callback)\n\u001b[0;32m-> 2053\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m wrap_df(\u001b[43mldf\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollect\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcallback\u001b[49m\u001b[43m)\u001b[49m)\n",
+ "\u001b[0;31mColumnNotFoundError\u001b[0m: 'explode' on column: 'is_organic' is invalid\n\nSchema at this point: Schema:\nname: _idx, field: UInt32\nname: uid, field: UInt32\nname: timestamp, field: List(UInt32)\nname: item_id, field: List(UInt32)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# The pipeline consumes the per-event columns and rebinds yambda_df, so running it a\n",
+ "# second time on the reduced frame fails with ColumnNotFoundError on 'is_organic'.\n",
+ "# Run it only while those columns are still present, which makes the cell re-runnable.\n",
+ "if {'is_organic', 'played_ratio_pct', 'track_length_seconds'} <= set(yambda_df.columns):\n",
+ "    yambda_df = (\n",
+ "        yambda_df\n",
+ "        .filter(pl.col('uid') % 200 == 0)  # TODO: remove this subsampling of users\n",
+ "        .with_row_index('_idx')  # row id used to regroup the exploded events\n",
+ "        .explode([\n",
+ "            'timestamp',\n",
+ "            'item_id',\n",
+ "            'is_organic',\n",
+ "            'played_ratio_pct',\n",
+ "            'track_length_seconds',\n",
+ "        ])\n",
+ "        # Keep events with is_organic == 0 that were played for at least 50 percent.\n",
+ "        .filter(\n",
+ "            (pl.col('is_organic') == 0) &\n",
+ "            (pl.col('played_ratio_pct') >= 50)\n",
+ "        )\n",
+ "        # Collapse back to one row per original row, keeping the event order.\n",
+ "        .group_by(['_idx', 'uid'], maintain_order=True)\n",
+ "        .agg([\n",
+ "            pl.col('timestamp'),\n",
+ "            pl.col('item_id'),\n",
+ "        ])\n",
+ "        .drop('_idx')\n",
+ "    )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "4HTturbHfjjU",
+ "outputId": "ea8b1d93-4997-441a-c6e3-2628acd11d7f"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "✅ test_yambda_filtering: OK\n"
+ ]
+ }
+ ],
+ "source": [
+ "def test_yambda_filtering():\n",
+ "    \"\"\"Check the filtered ``yambda_df``: user count, columns, list dtypes, event and item totals.\"\"\"\n",
+ "    n_users = yambda_df.shape[0]\n",
+ "    assert n_users == 4289, (\n",
+ "        f'Неправильное количество пользователей: {yambda_df.shape[0]}'\n",
+ "    )\n",
+ "\n",
+ "    wanted = {'uid', 'timestamp', 'item_id'}\n",
+ "    got = set(yambda_df.columns)\n",
+ "    assert got == wanted, (\n",
+ "        f'Неправильные колонки. Ожидалось: {wanted}, получено: {got}'\n",
+ "    )\n",
+ "\n",
+ "    # Both sequence columns must still be lists of unsigned 32-bit integers.\n",
+ "    list_u32 = pl.List(pl.UInt32)\n",
+ "    for name in ('timestamp', 'item_id'):\n",
+ "        assert yambda_df[name].dtype == list_u32, (\n",
+ "            f\"{name} должен быть List[UInt32], получено: {yambda_df[name].dtype}\"\n",
+ "        )\n",
+ "\n",
+ "    seq_lengths = yambda_df['item_id'].list.len()\n",
+ "    assert seq_lengths.min() >= 1, (\n",
+ "        f'Минимальная длина последовательности должна быть >= 1, получено: {seq_lengths.min()}'\n",
+ "    )\n",
+ "    assert seq_lengths.sum() == 7587469, (\n",
+ "        f'Общее количество событий неверно. Ожидалось: 7587469, получено: {seq_lengths.sum()}'\n",
+ "    )\n",
+ "\n",
+ "    distinct_items = yambda_df.select('item_id').explode('item_id').unique()\n",
+ "    unique_items = distinct_items.shape[0]\n",
+ "    assert unique_items == 304787, (\n",
+ "        f'Количество уникальных айтемов неверно. Ожидалось: 304787, получено: {unique_items}'\n",
+ "    )\n",
+ "\n",
+ "    print('✅ test_yambda_filtering: OK')\n",
+ "\n",
+ "test_yambda_filtering()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Оригинальные эмбеддинги: (7721749, 3)\n",
+ "Колонки: ['item_id', 'embed', 'normalized_embed']\n",
+ "shape: (5, 3)\n",
+ "┌─────────┬─────────────────────────────────┬─────────────────────────────────┐\n",
+ "│ item_id ┆ embed ┆ normalized_embed │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ u32 ┆ list[f64] ┆ list[f64] │\n",
+ "╞═════════╪═════════════════════════════════╪═════════════════════════════════╡\n",
+ "│ 2 ┆ [-1.534035, -0.366767, … 0.999… ┆ [-0.064638, -0.015454, … 0.042… │\n",
+ "│ 3 ┆ [-3.761467, -1.068254, … -2.66… ┆ [-0.163937, -0.046558, … -0.11… │\n",
+ "│ 4 ┆ [2.445533, -2.523603, … -0.536… ┆ [0.076272, -0.078707, … -0.016… │\n",
+ "│ 5 ┆ [0.832846, 0.116125, … -1.4857… ┆ [0.03149, 0.004391, … -0.05617… │\n",
+ "│ 6 ┆ [-2.431483, -0.56872, … 0.0946… ┆ [-0.10345, -0.024197, … 0.0040… │\n",
+ "└─────────┴─────────────────────────────────┴─────────────────────────────────┘\n"
+ ]
+ }
+ ],
+ "source": [
+ "import polars as pl\n",
+ "import pandas as pd\n",
+ "import pickle\n",
+ "\n",
+ "# === 1. Load the original embeddings ===\n",
+ "embeddings_path = \"/home/jovyan/yambda_embeddings/embeddings.parquet\"\n",
+ "emb_df = pl.read_parquet(embeddings_path)\n",
+ "\n",
+ "# Report the table shape, its column names and the first rows.\n",
+ "print(\n",
+ "    f\"Оригинальные эмбеддинги: {emb_df.shape}\",\n",
+ "    f\"Колонки: {emb_df.columns}\",\n",
+ "    emb_df.head(),\n",
+ "    sep=\"\\n\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Валидные item_id: 7721749\n",
+ "Было строк: 4289\n",
+ "Стало строк: 4138\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Shortest history (number of events) that is kept after the filtering.\n",
+ "MIN_SEQ_LEN = 5\n",
+ "\n",
+ "# Item ids that have an embedding. Built a single time: the set and the Series are large.\n",
+ "valid_item_ids = set(emb_df['item_id'].to_list())\n",
+ "print(f\"\\nВалидные item_id: {len(valid_item_ids)}\")\n",
+ "valid_ids_pl = pl.Series(list(valid_item_ids))\n",
+ "\n",
+ "yambda_df_filtered = (\n",
+ "    yambda_df\n",
+ "    .with_columns(\n",
+ "        # Per row: positions of the items that have an embedding.\n",
+ "        pl.col(\"item_id\").list.eval(\n",
+ "            pl.when(pl.element().is_in(valid_ids_pl))\n",
+ "            .then(pl.int_range(pl.len()))\n",
+ "            .otherwise(None)\n",
+ "        ).list.drop_nulls().alias(\"valid_indices\")\n",
+ "    )\n",
+ "    .with_columns([\n",
+ "        # Gather both lists by the same positions so they stay aligned.\n",
+ "        pl.col(\"item_id\").list.gather(pl.col(\"valid_indices\")),\n",
+ "        pl.col(\"timestamp\").list.gather(pl.col(\"valid_indices\"))\n",
+ "    ])\n",
+ "    .drop(\"valid_indices\")\n",
+ "    .rename({\"item_id\": \"item_ids\", \"timestamp\": \"timestamps\"})\n",
+ "    # Removes emptied histories as well as those shorter than MIN_SEQ_LEN.\n",
+ "    .filter(pl.col(\"item_ids\").list.len() >= MIN_SEQ_LEN)\n",
+ ")\n",
+ "\n",
+ "print(f\"Было строк: {yambda_df.shape[0]}\")\n",
+ "print(f\"Стало строк: {yambda_df_filtered.shape[0]}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "3️⃣ Получите все уникальные ID треков из датасета и создайте маппинг: старый_id - новый_id, где новый_id находится в диапазоне от 0 до N - 1.\n",
+ "\n",
+ "Модели глубокого обучения требуют, чтобы категориальные признаки (в нашем случае ID треков) были представлены целыми числами в диапазоне от 0 до N-1, где N — количество уникальных треков. Датасет Yambda содержит оригинальные ID треков, которые могут быть разреженными (например, [100, 5000, 7, 12000, ...]) — это неэффективно для embedding-таблиц."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Distinct raw item ids sorted ascending; the row index becomes the new dense id.\n",
+ "unique_items = (\n",
+ "    yambda_df_filtered\n",
+ "    .select('item_ids')\n",
+ "    .explode('item_ids')\n",
+ "    .unique()\n",
+ "    .sort('item_ids')\n",
+ "    .with_row_index('new_item_ids')\n",
+ ")\n",
+ "\n",
+ "# Lookup table: old item id -> new item id.\n",
+ "item_mapping = dict(zip(unique_items['item_ids'], unique_items['new_item_ids']))\n",
+ "\n",
+ "\n",
+ "def _remap_items(items):\n",
+ "    \"\"\"Translate one list of old item ids into new ids via item_mapping.\"\"\"\n",
+ "    return [item_mapping[item] for item in items]\n",
+ "\n",
+ "\n",
+ "# NOTE(review): yambda_df_filtered is rebound here, so the cell is presumably meant to run once per session.\n",
+ "yambda_df_filtered = yambda_df_filtered.with_columns(\n",
+ "    pl.col('item_ids')\n",
+ "    .map_elements(_remap_items, return_dtype=pl.List(pl.UInt32))\n",
+ "    .alias('item_ids')\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "✅ test_item_mapping: OK\n"
+ ]
+ }
+ ],
+ "source": [
+ "def test_item_mapping():\n",
+ "    \"\"\"Check unique_items, item_mapping and the re-indexed item_ids column against the expected counts.\"\"\"\n",
+ "    assert unique_items.shape == (292865, 2), f'Неправильный размер unique_items: {unique_items.shape}'\n",
+ "    assert set(unique_items.columns) == {'new_item_ids', 'item_ids'}, 'Неправильные колонки unique_items'\n",
+ "\n",
+ "    assert len(item_mapping) == 292865, f'Неправильный размер item_mapping: {len(item_mapping)}'\n",
+ "    # The three smallest old ids must map to 0, 1 and 2.\n",
+ "    first_pairs = ((50, 0), (175, 1), (195, 2))\n",
+ "    assert all(item_mapping[old_id] == new_id for old_id, new_id in first_pairs), \\\n",
+ "        'Неверные первые маппинги'\n",
+ "\n",
+ "    new_ids = unique_items['new_item_ids']\n",
+ "    id_range = (new_ids.min(), new_ids.max())\n",
+ "    assert id_range == (0, 292864), 'new_item_id должны быть в [0, 292865,]'\n",
+ "\n",
+ "    all_ids = yambda_df_filtered.select('item_ids').explode('item_ids')['item_ids']\n",
+ "    assert (all_ids.min(), all_ids.max()) == (0, 292864), 'item_id в yambda_df не обновлены'\n",
+ "    assert all_ids.n_unique() == 292865, 'Количество уникальных item_id изменилось'\n",
+ "\n",
+ "    print('✅ test_item_mapping: OK')\n",
+ "\n",
+ "test_item_mapping()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "Маппинг использует: 292865 уникальных item_id\n",
+ "Переиндексированные эмбеддинги: (292865, 2)\n",
+ "shape: (5, 2)\n",
+ "┌─────────────────────────────────┬─────────┐\n",
+ "│ embedding ┆ item_id │\n",
+ "│ --- ┆ --- │\n",
+ "│ list[f64] ┆ u32 │\n",
+ "╞═════════════════════════════════╪═════════╡\n",
+ "│ [-0.0526, 0.048672, … -0.04217… ┆ 0 │\n",
+ "│ [0.090222, -0.00718, … -0.0862… ┆ 1 │\n",
+ "│ [-0.00822, -0.057882, … 0.2188… ┆ 2 │\n",
+ "│ [-0.107289, -0.034719, … -0.02… ┆ 3 │\n",
+ "│ [0.012762, -0.043315, … 0.1494… ┆ 4 │\n",
+ "└─────────────────────────────────┴─────────┘\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"\\nМаппинг использует: {len(item_mapping)} уникальных item_id\")\n",
+ "\n",
+ "\n",
+ "def _lookup_new_id(old_id):\n",
+ "    \"\"\"Return the new id for old_id, or None when item_mapping has no entry for it.\"\"\"\n",
+ "    return item_mapping.get(old_id, None)\n",
+ "\n",
+ "\n",
+ "# === 3. Re-index the embeddings with the existing mapping ===\n",
+ "new_id_expr = pl.col('item_id').map_elements(_lookup_new_id, return_dtype=pl.UInt32).alias('new_item_id')\n",
+ "emb_df_reindexed = (\n",
+ "    emb_df\n",
+ "    .with_columns(new_id_expr)\n",
+ "    # Rows whose item has no entry in item_mapping get a null id and are removed.\n",
+ "    .filter(pl.col('new_item_id').is_not_null())\n",
+ "    .drop('item_id', 'embed')\n",
+ "    .rename({'new_item_id': 'item_id', 'normalized_embed': 'embedding'})\n",
+ ")\n",
+ "\n",
+ "print(f\"Переиндексированные эмбеддинги: {emb_df_reindexed.shape}\")\n",
+ "print(emb_df_reindexed.head())\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "✅ test_item_mapping: OK\n"
+ ]
+ }
+ ],
+ "source": [
+ "def test_emb_item_mapping():\n",
+ "    \"\"\"Check that item_id in emb_df_reindexed spans 0..292864 with 292865 distinct values.\"\"\"\n",
+ "    all_ids = emb_df_reindexed['item_id']\n",
+ "    # The messages name this test and this frame (they were copied from test_item_mapping).\n",
+ "    assert all_ids.min() == 0 and all_ids.max() == 292864, 'item_id в emb_df_reindexed не обновлены'\n",
+ "    assert all_ids.n_unique() == 292865, 'Количество уникальных item_id изменилось'\n",
+ "\n",
+ "    print('✅ test_emb_item_mapping: OK')\n",
+ "\n",
+ "test_emb_item_mapping()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "✓ Сохранены embeddings: /home/jovyan/IRec/sigir/yambda_data/yambda_embeddings_reindexed.parquet\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Write the re-indexed embedding table to parquet and report the destination path.\n",
+ "embeddings_output_parquet_path = \"/home/jovyan/IRec/sigir/yambda_data/yambda_embeddings_reindexed.parquet\"\n",
+ "emb_df_reindexed.write_parquet(embeddings_output_parquet_path)\n",
+ "print(f\"\\n✓ Сохранены embeddings: {embeddings_output_parquet_path}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Тест пройден: все 4138 строк синхронизированы\n"
+ ]
+ }
+ ],
+ "source": [
+ "def test_integrity(df, min_len=5):\n",
+ "    \"\"\"\n",
+ "    Verify that the item_ids and timestamps lists are aligned and long enough.\n",
+ "\n",
+ "    Args:\n",
+ "        df: Polars frame with list columns 'item_ids' and 'timestamps'.\n",
+ "        min_len: smallest allowed number of timestamps per row.\n",
+ "\n",
+ "    Raises:\n",
+ "        ValueError: if a row has lists of different lengths, or fewer than\n",
+ "            min_len timestamps. The two problems are reported separately.\n",
+ "    \"\"\"\n",
+ "    n_items = pl.col(\"item_ids\").list.len()\n",
+ "    n_times = pl.col(\"timestamps\").list.len()\n",
+ "\n",
+ "    desynced = df.filter(n_items != n_times).height\n",
+ "    too_short = df.filter(n_times < min_len).height\n",
+ "\n",
+ "    if desynced > 0:\n",
+ "        print(f\"ОШИБКА: {desynced} строк рассинхронизированы!\")\n",
+ "        raise ValueError(\"Рассинхрон массивов!\")\n",
+ "\n",
+ "    # A short history is not a desync, so it gets its own message.\n",
+ "    if too_short > 0:\n",
+ "        print(f\"ОШИБКА: {too_short} строк короче {min_len} событий!\")\n",
+ "        raise ValueError(f\"Найдены последовательности короче {min_len} событий!\")\n",
+ "\n",
+ "    print(f\"Тест пройден: все {df.height} строк синхронизированы\")\n",
+ "\n",
+ "test_integrity(yambda_df_filtered)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Сохранён filtered yambda_df: /home/jovyan/IRec/sigir/yambda_data/yambda_sequential_50m_filtered_reindexed.parquet\n"
+ ]
+ }
+ ],
+ "source": [
+ "yambda_output_parquet_path = \"/home/jovyan/IRec/sigir/yambda_data/yambda_sequential_50m_filtered_reindexed.parquet\"\n",
+ "yambda_df_filtered.write_parquet(yambda_output_parquet_path)\n",
+ "print(f\"Сохранён filtered yambda_df: {yambda_output_parquet_path}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Сохранён маппинг: /home/jovyan/IRec/sigir/yambda_data/old_to_new_item_id_mapping.json\n"
+ ]
+ }
+ ],
+ "source": [
+ "import json\n",
+ "mapping_output_path = \"/home/jovyan/IRec/sigir/yambda_data/old_to_new_item_id_mapping.json\"\n",
+ "\n",
+ "with open(mapping_output_path, 'w') as f:\n",
+ " json.dump({str(k): v for k, v in item_mapping.items()}, f, indent=2)\n",
+ "\n",
+ "print(f\"Сохранён маппинг: {mapping_output_path}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (10, 3)| uid | timestamps | item_ids |
|---|
| u32 | list[u32] | list[u32] |
| 600 | [1329190, 1329405, … 25997540] | [252026, 58171, … 201909] |
| 800 | [121100, 121290, … 25977310] | [20844, 198210, … 60455] |
| 1000 | [11335730, 11335925, … 25972225] | [46643, 57592, … 95670] |
| 1400 | [280570, 280735, … 25993315] | [4634, 213798, … 104891] |
| 1600 | [899275, 930305, … 25941890] | [223933, 154424, … 104876] |
| 2000 | [18814620, 18828965, … 25225145] | [137828, 138498, … 19072] |
| 2200 | [10053900, 10054120, … 25948025] | [4923, 231643, … 28122] |
| 2400 | [14246260, 14246390, … 25999860] | [157350, 217652, … 75038] |
| 2600 | [6089640, 6089915, … 25951140] | [9426, 202953, … 140393] |
| 2800 | [19744285, 19744475, … 25894825] | [123607, 291065, … 272888] |
"
+ ],
+ "text/plain": [
+ "shape: (10, 3)\n",
+ "┌──────┬─────────────────────────────────┬────────────────────────────┐\n",
+ "│ uid ┆ timestamps ┆ item_ids │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ u32 ┆ list[u32] ┆ list[u32] │\n",
+ "╞══════╪═════════════════════════════════╪════════════════════════════╡\n",
+ "│ 600 ┆ [1329190, 1329405, … 25997540] ┆ [252026, 58171, … 201909] │\n",
+ "│ 800 ┆ [121100, 121290, … 25977310] ┆ [20844, 198210, … 60455] │\n",
+ "│ 1000 ┆ [11335730, 11335925, … 2597222… ┆ [46643, 57592, … 95670] │\n",
+ "│ 1400 ┆ [280570, 280735, … 25993315] ┆ [4634, 213798, … 104891] │\n",
+ "│ 1600 ┆ [899275, 930305, … 25941890] ┆ [223933, 154424, … 104876] │\n",
+ "│ 2000 ┆ [18814620, 18828965, … 2522514… ┆ [137828, 138498, … 19072] │\n",
+ "│ 2200 ┆ [10053900, 10054120, … 2594802… ┆ [4923, 231643, … 28122] │\n",
+ "│ 2400 ┆ [14246260, 14246390, … 2599986… ┆ [157350, 217652, … 75038] │\n",
+ "│ 2600 ┆ [6089640, 6089915, … 25951140] ┆ [9426, 202953, … 140393] │\n",
+ "│ 2800 ┆ [19744285, 19744475, … 2589482… ┆ [123607, 291065, … 272888] │\n",
+ "└──────┴─────────────────────────────────┴────────────────────────────┘"
+ ]
+ },
+ "execution_count": 44,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "yambda_df_filtered.head(10)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/sigir/yambda_processing/yambda_exps_data.ipynb b/sigir/yambda_processing/yambda_exps_data.ipynb
new file mode 100644
index 0000000..e8f8d8b
--- /dev/null
+++ b/sigir/yambda_processing/yambda_exps_data.ipynb
@@ -0,0 +1,1168 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "e2462a97-6705-44e1-a232-4dd78a5dfc85",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import polars as pl\n",
+ "import json\n",
+ "from typing import List, Dict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "fd38624d-5796-4aa5-929f-7e82c5544f6c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "interactions_output_parquet_path = '/home/jovyan/IRec/sigir/yambda_data/yambda_sequential_50m_filtered_reindexed.parquet'\n",
+ "df = pl.read_parquet(interactions_output_parquet_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "69066941",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def merge_and_save(parts_to_merge, dirr, output_name):\n",
+ " merged = {}\n",
+ " print(f\"Merging {len(parts_to_merge)} files into {output_name}...\")\n",
+ " \n",
+ " for part in parts_to_merge:\n",
+ " # with open(fp, 'r') as f:\n",
+ " # part = json.load(f)\n",
+ " for uid, items in part.items():\n",
+ " if uid not in merged:\n",
+ " merged[uid] = []\n",
+ " merged[uid].extend(items)\n",
+ " \n",
+ " out_path = f\"{dirr}/{output_name}\"\n",
+ " with open(out_path, 'w') as f:\n",
+ " json.dump(merged, f)\n",
+ " print(f\"✓ Done: {out_path} (Users: {len(merged)})\")\n",
+ "\n",
+ "\n",
+ "def merge_and_save_with_filter(parts_to_merge, dirr, output_name, min_history_len=5):\n",
+ " merged = {}\n",
+ " print(f\"Merging {len(parts_to_merge)} files into {output_name} (min len={min_history_len})...\")\n",
+ " \n",
+ " for part in parts_to_merge:\n",
+ " for uid, items in part.items():\n",
+ " if uid not in merged:\n",
+ " merged[uid] = []\n",
+ " merged[uid].extend(items)\n",
+ "\n",
+ " filtered_merged = {}\n",
+ " filtered_count = 0\n",
+ " \n",
+ " for uid, items in merged.items():\n",
+ " if len(items) >= min_history_len:\n",
+ " filtered_merged[uid] = items\n",
+ " else:\n",
+ " filtered_count += 1\n",
+ " \n",
+ " print(f\"Filtered {filtered_count} users with history < {min_history_len}\")\n",
+ " print(f\"Remaining: {len(filtered_merged)} users\")\n",
+ " \n",
+ " out_path = f\"{dirr}/{output_name}\"\n",
+ " with open(out_path, 'w') as f:\n",
+ " json.dump(filtered_merged, f)\n",
+ " print(f\"Done: {out_path} (Users: {len(filtered_merged)})\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "ee127317-66b8-4f22-9109-94bcb8b1f1ae",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def split_session_by_timestamps(\n",
+ " df: pl.DataFrame,\n",
+ " time_cutoffs: List[int],\n",
+ " output_dir: str = None,\n",
+ " return_dicts: bool = True\n",
+ ") -> List[Dict[int, List[int]]]:\n",
+ " \"\"\"\n",
+ " Args:\n",
+ " df: Polars DataFrame с колонками uid, item_ids (list), timestamps (list)\n",
+ " time_cutoffs: Лист временных точек для разбиения\n",
+ " output_dir: Директория для сохранения JSON файлов (опционально)\n",
+ " return_dicts: Сейчас не используется — список словарей возвращается всегда\n",
+ " \n",
+ " Возвращает лист словарей в формате {user_id: [item_ids для интервала]}\n",
+ " \"\"\"\n",
+ " \n",
+ " result_dicts = []\n",
+ " \n",
+ " def extract_interval(df_source, start, end=None):\n",
+ " q = df_source.lazy()\n",
+ " q = q.explode([\"item_ids\", \"timestamps\"])\n",
+ " \n",
+ " if end is not None:\n",
+ " q = q.filter(\n",
+ " (pl.col(\"timestamps\") >= start) & \n",
+ " (pl.col(\"timestamps\") < end)\n",
+ " )\n",
+ " else:\n",
+ " q = q.filter(\n",
+ " pl.col(\"timestamps\") >= start\n",
+ " )\n",
+ " \n",
+ " q = q.group_by(\"uid\").agg([\n",
+ " pl.col(\"item_ids\").alias(\"item_ids\")\n",
+ " ]).sort(\"uid\")\n",
+ " \n",
+ " return q.collect()\n",
+ " \n",
+ " intervals = []\n",
+ " current_start = 0\n",
+ " for cutoff in time_cutoffs:\n",
+ " intervals.append((current_start, cutoff))\n",
+ " current_start = cutoff\n",
+ "\n",
+ " intervals.append((current_start, None))\n",
+ "\n",
+ " for start, end in intervals:\n",
+ " subset = extract_interval(df, start, end)\n",
+ "\n",
+ " json_dict = {}\n",
+ " for user_id, item_ids in subset.iter_rows():\n",
+ " json_dict[user_id] = item_ids\n",
+ " \n",
+ " result_dicts.append(json_dict)\n",
+ "\n",
+ " if output_dir:\n",
+ " if end is not None:\n",
+ " filename = f\"inter_new_[{start}_{end}).json\"\n",
+ " else:\n",
+ " filename = f\"inter_new_[{start}_inf).json\"\n",
+ " \n",
+ " filepath = f\"{output_dir}/{filename}\"\n",
+ " with open(filepath, 'w') as f:\n",
+ " json.dump(json_dict, f, indent=2)\n",
+ " \n",
+ " print(f\"✓ Сохранено: {filepath}\")\n",
+ " \n",
+ " return result_dicts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "6cff8e7b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (5, 3)| uid | timestamps | item_ids |
|---|
| u32 | list[u32] | list[u32] |
| 600 | [1329190, 1329405, … 25997540] | [252026, 58171, … 201909] |
| 800 | [121100, 121290, … 25977310] | [20844, 198210, … 60455] |
| 1000 | [11335730, 11335925, … 25972225] | [46643, 57592, … 95670] |
| 1400 | [280570, 280735, … 25993315] | [4634, 213798, … 104891] |
| 1600 | [899275, 930305, … 25941890] | [223933, 154424, … 104876] |
"
+ ],
+ "text/plain": [
+ "shape: (5, 3)\n",
+ "┌──────┬─────────────────────────────────┬────────────────────────────┐\n",
+ "│ uid ┆ timestamps ┆ item_ids │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ u32 ┆ list[u32] ┆ list[u32] │\n",
+ "╞══════╪═════════════════════════════════╪════════════════════════════╡\n",
+ "│ 600 ┆ [1329190, 1329405, … 25997540] ┆ [252026, 58171, … 201909] │\n",
+ "│ 800 ┆ [121100, 121290, … 25977310] ┆ [20844, 198210, … 60455] │\n",
+ "│ 1000 ┆ [11335730, 11335925, … 2597222… ┆ [46643, 57592, … 95670] │\n",
+ "│ 1400 ┆ [280570, 280735, … 25993315] ┆ [4634, 213798, … 104891] │\n",
+ "│ 1600 ┆ [899275, 930305, … 25941890] ┆ [223933, 154424, … 104876] │\n",
+ "└──────┴─────────────────────────────────┴────────────────────────────┘"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "901e7400",
+ "metadata": {},
+ "source": [
+ "# QUANTILE CUTOFF"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "8c691891",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_quantile_cutoffs(df, num_parts=4, base_ratio=None):\n",
+ " \"\"\"\n",
+ " Считает cutoffs так, чтобы разбить данные на части.\n",
+ " \n",
+ " Args:\n",
+ " num_parts: На сколько частей делить \"хвост\" истории.\n",
+ " base_ratio: Какую долю данных отдать в Base (самую первую часть). \n",
+ " Если None, делит всё поровну.\n",
+ " \"\"\"\n",
+ " # Достаем все таймстемпы в один плоский массив\n",
+ " # Это может занять память, если данных очень много (>100M), но для Yambda (~7.4M событий) это ок\n",
+ " all_ts = df.select(pl.col(\"timestamps\").explode()).to_series().sort()\n",
+ " total_events = len(all_ts)\n",
+ " \n",
+ " print(f\"Всего событий: {total_events}\")\n",
+ " \n",
+ " cutoffs = []\n",
+ " \n",
+ " if base_ratio:\n",
+ " # Base занимает долю base_ratio (например 80%), а остаток делим поровну на num_parts частей\n",
+ " # Остаток = 1 - base_ratio\n",
+ " # Каждая малая часть = (1 - base_ratio) / num_parts\n",
+ " \n",
+ " base_idx = int(total_events * base_ratio)\n",
+ " cutoffs.append(all_ts[base_idx]) # Первый cutoff отделяет Base\n",
+ " \n",
+ " remaining_events = total_events - base_idx\n",
+ " part_size = remaining_events // num_parts # Делим остаток на num_parts равных частей\n",
+ " \n",
+ " current_idx = base_idx\n",
+ " for _ in range(num_parts-1): # Нужно ещё num_parts - 1 границ, чтобы получить num_parts частей\n",
+ " current_idx += part_size\n",
+ " cutoffs.append(all_ts[current_idx])\n",
+ " \n",
+ " else:\n",
+ " # Сценарий: Просто делим всё на N равных частей\n",
+ " step = total_events // num_parts\n",
+ " for i in range(1, num_parts):\n",
+ " idx = i * step\n",
+ " cutoffs.append(all_ts[idx])\n",
+ " \n",
+ " return cutoffs\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "13c1466f",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Всего событий: 7371990\n",
+ "\n",
+ "--- Новые Cutoffs (по количеству событий) ---\n",
+ "Cutoffs: [22138015, 23136375, 24137410, 25093085]\n",
+ "[0, 22138015, 23136375, 24137410, 25093085, None]\n"
+ ]
+ }
+ ],
+ "source": [
+ "equal_event_cutoffs = get_quantile_cutoffs(df, num_parts=4, base_ratio=0.8)\n",
+ "\n",
+ "print(\"\\n--- Новые Cutoffs (по количеству событий) ---\")\n",
+ "print(f\"Cutoffs: {equal_event_cutoffs}\")\n",
+ "\n",
+ "# Проверка распределения\n",
+ "intervals_eq = [0] + equal_event_cutoffs + [None]\n",
+ "print(intervals_eq)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "4e7f7b46",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "✓ Сохранено: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/raw/inter_new_[0_22138015).json\n",
+ "✓ Сохранено: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/raw/inter_new_[22138015_24137410).json\n",
+ "✓ Сохранено: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/raw/inter_new_[24137410_25093085).json\n",
+ "✓ Сохранено: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/raw/inter_new_[25093085_inf).json\n",
+ "0 Base 3813 5897592 \n",
+ "1 Gap 3315 737198 \n",
+ "2 Valid 3120 368599 \n",
+ "3 Test 3154 368601 \n"
+ ]
+ }
+ ],
+ "source": [
+ "new_split_files = split_session_by_timestamps(\n",
+ " df, \n",
+ " [22138015, 24137410, 25093085], \n",
+ " output_dir=\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/raw\"\n",
+ ")\n",
+ "\n",
+ "names = [\"Base\", \"Gap\", \"Valid\", \"Test\"]\n",
+ "for i, d in enumerate(new_split_files):\n",
+ " num_users = len(d)\n",
+ " \n",
+ " num_events = sum(len(items) for items in d.values())\n",
+ " \n",
+ " print(f\"{i:<10} {names[i]:<10} {num_users:<10} {num_events:<10}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "82fd2bca",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Merging 2 files into exp_4_0.9_inter_tiger_train.json...\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4_0.9_inter_tiger_train.json (Users: 4016)\n",
+ "Merging 2 files into exp_4-1_0.9_inter_semantics_train.json...\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-1_0.9_inter_semantics_train.json (Users: 4016)\n",
+ "Merging 1 files into exp_4-2_0.8_inter_semantics_train.json...\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-2_0.8_inter_semantics_train.json (Users: 3813)\n",
+ "Merging 3 files into exp_4-3_0.95_inter_semantics_train.json...\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-3_0.95_inter_semantics_train.json (Users: 4118)\n",
+ "Merging 1 files into test_set.json...\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/test_set.json (Users: 3154)\n",
+ "Merging 1 files into valid_set.json...\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/valid_set.json (Users: 3120)\n",
+ "Merging 4 files into all_set.json...\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/all_set.json (Users: 4138)\n",
+ "All done!\n"
+ ]
+ }
+ ],
+ "source": [
+ "EXP_DIR = \"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps\"\n",
+ "\n",
+ "base_p, gap_p, valid_p, test_p = new_split_files[0], new_split_files[1], new_split_files[2], new_split_files[3]\n",
+ "\n",
+ "# Tiger: base + gap\n",
+ "merge_and_save([base_p, gap_p], EXP_DIR, \"exp_4_0.9_inter_tiger_train.json\")\n",
+ "\n",
+ "# 1. Exp 4.1 (Standard)\n",
+ "# Semantics: base + gap (Всё кроме валидации и теста)\n",
+ "merge_and_save([base_p, gap_p], EXP_DIR, \"exp_4-1_0.9_inter_semantics_train.json\")\n",
+ "\n",
+ "# 2. Exp 4.2 (Short Semantics)\n",
+ "# Semantics: base (Короче на пропуск, без gap)\n",
+ "merge_and_save([base_p], EXP_DIR, \"exp_4-2_0.8_inter_semantics_train.json\")\n",
+ "\n",
+ "# 3. Exp 4.3 (Leak)\n",
+ "# Semantics: base + gap + valid (Видит валидацию)\n",
+ "merge_and_save([base_p, gap_p, valid_p], EXP_DIR, \"exp_4-3_0.95_inter_semantics_train.json\")\n",
+ "\n",
+ "# 4. Test Set (тест всех моделей)\n",
+ "merge_and_save([test_p], EXP_DIR, \"test_set.json\")\n",
+ "\n",
+ "# 5. Valid Set (валидационный набор)\n",
+ "merge_and_save([valid_p], EXP_DIR, \"valid_set.json\")\n",
+ "\n",
+ "# 6. All Set (все данные)\n",
+ "merge_and_save([base_p, gap_p, valid_p, test_p], EXP_DIR, \"all_set.json\")\n",
+ "\n",
+ "print(\"All done!\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "id": "d34b1c55",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-12-11T08:56:58.546300Z",
+ "start_time": "2025-12-11T08:56:58.343394Z"
+ }
+ },
+ "source": [
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/all_set.json\", 'r') as f:\n",
+ " old_inter_new = json.load(f)\n",
+ "\n",
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-1_0.9_inter_semantics_train.json\", 'r') as ff:\n",
+ " first_sem = json.load(ff)\n",
+ " \n",
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-2_0.8_inter_semantics_train.json\", 'r') as ff:\n",
+ " second_sem = json.load(ff)\n",
+ " \n",
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-3_0.95_inter_semantics_train.json\", 'r') as ff:\n",
+ " third_sem = json.load(ff)\n",
+ " \n",
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4_0.9_inter_tiger_train.json\", 'r') as ff:\n",
+ " tiger_sem = json.load(ff)\n",
+ "\n",
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/test_set.json\", 'r') as ff:\n",
+ " test_sem = json.load(ff)\n",
+ "\n",
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/all_set.json\", 'r') as ff:\n",
+ " all_test_data = json.load(ff)\n",
+ "\n",
+ "def check_prefix_match(full_data, subset_data, check_suffix=False):\n",
+ " \"\"\"\n",
+ " check_suffix=True включит режим проверки суффиксов (для теста).\n",
+ " \"\"\"\n",
+ " mismatch_count = 0\n",
+ " full_match_count = 0\n",
+ "\n",
+ " num_events_full_data = sum(len(items) for items in full_data.values())\n",
+ " num_events_subset_data = sum(len(items) for items in subset_data.values())\n",
+ " print(f\"доля событий всего {(num_events_subset_data/num_events_full_data):.2f}:\")\n",
+ " \n",
+ " for user, sub_items in subset_data.items():\n",
+ " \n",
+ " if user not in full_data:\n",
+ " print(f\"⚠ Юзер {user} не найден в исходном файле!\")\n",
+ " mismatch_count += 1\n",
+ " continue\n",
+ " \n",
+ " full_items = full_data[user]\n",
+ " \n",
+ " if not check_suffix:\n",
+ " if len(sub_items) > len(full_items):\n",
+ " mismatch_count += 1\n",
+ " continue\n",
+ " \n",
+ " if full_items[:len(sub_items)] == sub_items:\n",
+ " if len(full_items) == len(sub_items):\n",
+ " full_match_count += 1\n",
+ " else:\n",
+ " mismatch_count += 1\n",
+ "\n",
+ " else:\n",
+ " if len(sub_items) > len(full_items):\n",
+ " mismatch_count += 1\n",
+ " continue\n",
+ "\n",
+ " if full_items[-len(sub_items):] == sub_items:\n",
+ " if len(full_items) == len(sub_items):\n",
+ " full_match_count += 1\n",
+ " else:\n",
+ " mismatch_count += 1\n",
+ "\n",
+ " mode = \"СУФФИКСЫ\" if check_suffix else \"ПРЕФИКСЫ\"\n",
+ " \n",
+ " if mismatch_count == 0:\n",
+ " print(f\"OK [{mode}] Все {len(subset_data)} массивов ОК. Полных совпадений: {full_match_count}\")\n",
+ " else:\n",
+ " print(f\"NOT OK [{mode}] Найдено {mismatch_count} ошибок.\")\n",
+ "\n",
+ "# --- Запуск проверок ---\n",
+ "print(\"Проверка Train сетов (должны быть префиксами):\")\n",
+ "check_prefix_match(old_inter_new, first_sem)\n",
+ "check_prefix_match(old_inter_new, second_sem)\n",
+ "check_prefix_match(old_inter_new, third_sem)\n",
+ "check_prefix_match(old_inter_new, tiger_sem)\n",
+ "\n",
+ "print(\"\\nПроверка Test сета (должен быть суффиксом):\")\n",
+ "check_prefix_match(old_inter_new, test_sem, check_suffix=True)\n",
+ "\n",
+ "print(\"\\n(Контроль) Проверка Test сета как префикса (должна упасть):\")\n",
+ "check_prefix_match(old_inter_new, test_sem, check_suffix=False)\n",
+ "\n",
+ "check_prefix_match(old_inter_new, all_test_data)\n"
+ ],
+ "outputs": [
+ {
+ "ename": "FileNotFoundError",
+ "evalue": "[Errno 2] No such file or directory: '/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/all_set.json'",
+ "output_type": "error",
+ "traceback": [
+ "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
+ "\u001B[0;31mFileNotFoundError\u001B[0m Traceback (most recent call last)",
+ "Cell \u001B[0;32mIn[1], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28;43mopen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[38;5;124;43m/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/all_set.json\u001B[39;49m\u001B[38;5;124;43m\"\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mr\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m)\u001B[49m \u001B[38;5;28;01mas\u001B[39;00m f:\n\u001B[1;32m 2\u001B[0m old_inter_new \u001B[38;5;241m=\u001B[39m json\u001B[38;5;241m.\u001B[39mload(f)\n\u001B[1;32m 4\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mopen\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/exp_4-1_0.9_inter_semantics_train.json\u001B[39m\u001B[38;5;124m\"\u001B[39m, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mr\u001B[39m\u001B[38;5;124m'\u001B[39m) \u001B[38;5;28;01mas\u001B[39;00m ff:\n",
+ "File \u001B[0;32m~/repositories/ucp-author-centric/ucp-env/lib/python3.9/site-packages/IPython/core/interactiveshell.py:310\u001B[0m, in \u001B[0;36m_modified_open\u001B[0;34m(file, *args, **kwargs)\u001B[0m\n\u001B[1;32m 303\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m file \u001B[38;5;129;01min\u001B[39;00m {\u001B[38;5;241m0\u001B[39m, \u001B[38;5;241m1\u001B[39m, \u001B[38;5;241m2\u001B[39m}:\n\u001B[1;32m 304\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mValueError\u001B[39;00m(\n\u001B[1;32m 305\u001B[0m \u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mIPython won\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mt let you open fd=\u001B[39m\u001B[38;5;132;01m{\u001B[39;00mfile\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m by default \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 306\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mas it is likely to crash IPython. If you know what you are doing, \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 307\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124myou can use builtins\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m open.\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[1;32m 308\u001B[0m )\n\u001B[0;32m--> 310\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mio_open\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfile\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
+ "\u001B[0;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: '/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps/all_set.json'"
+ ]
+ }
+ ],
+ "execution_count": 1
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "c3a0adf2",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "================================================================================\n",
+ "ПРОВЕРКА НА ПУСТЫЕ ЧАСТИ ИСТОРИЙ\n",
+ "================================================================================\n",
+ "\n",
+ "[exp_4-1_0.9] Анализ...\n",
+ " Юзеров в сплите: 4,016 / 4,138\n",
+ " ПУСТЫХ сессий: 15\n",
+ " ОБЩИХ ПРОБЛЕМ: 15\n",
+ "\n",
+ "[exp_4-2_0.8] Анализ...\n",
+ " Юзеров в сплите: 3,813 / 4,138\n",
+ " ПУСТЫХ сессий: 22\n",
+ " ОБЩИХ ПРОБЛЕМ: 22\n",
+ "\n",
+ "[exp_4-3_0.95] Анализ...\n",
+ " Юзеров в сплите: 4,118 / 4,138\n",
+ " ПУСТЫХ сессий: 7\n",
+ " ОБЩИХ ПРОБЛЕМ: 7\n",
+ "\n",
+ "[exp_4_0.9_tiger] Анализ...\n",
+ " Юзеров в сплите: 4,016 / 4,138\n",
+ " ПУСТЫХ сессий: 15\n",
+ " ОБЩИХ ПРОБЛЕМ: 15\n",
+ "\n",
+ "[test_set] Анализ...\n",
+ " Юзеров в сплите: 3,154 / 4,138\n",
+ " ПУСТЫХ сессий: 105\n",
+ " ОБЩИХ ПРОБЛЕМ: 105\n"
+ ]
+ }
+ ],
+ "source": [
+ "def check_non_empty_splits(full_data, splits_data, split_names, min_history_len=2):\n",
+ " \"\"\"\n",
+ " Для каждого разбиения считает и печатает число пользователей, чья история пуста или короче min_history_len.\n",
+ " \"\"\"\n",
+ " print(\"\\n\" + \"=\"*80)\n",
+ " print(\"ПРОВЕРКА НА ПУСТЫЕ ЧАСТИ ИСТОРИЙ\")\n",
+ " print(\"=\"*80)\n",
+ " \n",
+ " all_users = set(full_data.keys())\n",
+ " total_issues = 0\n",
+ " \n",
+ " for i in range(len(split_names)):\n",
+ " split_name = split_names[i]\n",
+ " split_data = splits_data[i]\n",
+ " print(f\"\\n[{split_name}] Анализ...\")\n",
+ " \n",
+ " split_users = set(split_data.keys())\n",
+ " empty_sessions = []\n",
+ " \n",
+ " for user, items in split_data.items():\n",
+ " if not items or len(items) < min_history_len:\n",
+ " empty_sessions.append(user)\n",
+ " \n",
+ " issues_count = len(empty_sessions)\n",
+ " total_issues += issues_count\n",
+ " \n",
+ " print(f\" Юзеров в сплите: {len(split_users):,} / {len(all_users):,}\")\n",
+ " print(f\" ПУСТЫХ сессий: {len(empty_sessions)}\")\n",
+ " print(f\" ОБЩИХ ПРОБЛЕМ: {issues_count}\")\n",
+ " \n",
+ " if total_issues == 0:\n",
+ " print(\"\\nВСЕ РАЗБИЕНИЯ БЕЗ ПУСТЫХ СЕССИЙ\")\n",
+ "\n",
+ "split_names = ['exp_4-1_0.9', 'exp_4-2_0.8', 'exp_4-3_0.95', 'exp_4_0.9_tiger', 'test_set']\n",
+ "splits_list = [first_sem, second_sem, third_sem, tiger_sem, test_sem]\n",
+ "\n",
+ "check_non_empty_splits(old_inter_new, splits_list, split_names)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "43aa0142",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Merging 2 files into exp_4_0.9_inter_tiger_train.json (min len=2)...\n",
+ "Filtered 15 users with history < 2\n",
+ "Remaining: 4001 users\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4_0.9_inter_tiger_train.json (Users: 4001)\n",
+ "Merging 2 files into exp_4-1_0.9_inter_semantics_train.json (min len=2)...\n",
+ "Filtered 15 users with history < 2\n",
+ "Remaining: 4001 users\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4-1_0.9_inter_semantics_train.json (Users: 4001)\n",
+ "Merging 1 files into exp_4-2_0.8_inter_semantics_train.json (min len=2)...\n",
+ "Filtered 22 users with history < 2\n",
+ "Remaining: 3791 users\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4-2_0.8_inter_semantics_train.json (Users: 3791)\n",
+ "Merging 3 files into exp_4-3_0.95_inter_semantics_train.json (min len=2)...\n",
+ "Filtered 7 users with history < 2\n",
+ "Remaining: 4111 users\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4-3_0.95_inter_semantics_train.json (Users: 4111)\n",
+ "Merging 1 files into test_set.json...\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/test_set.json (Users: 3154)\n",
+ "Merging 1 files into valid_set.json...\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/valid_set.json (Users: 3120)\n",
+ "Merging 4 files into all_set.json...\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/all_set.json (Users: 4138)\n",
+ "All done!\n"
+ ]
+ }
+ ],
+ "source": [
+ "EXP_DIR_FILTERED = \"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered\"\n",
+ "\n",
+ "base_p, gap_p, valid_p, test_p = new_split_files[0], new_split_files[1], new_split_files[2], new_split_files[3]\n",
+ "\n",
+ "# Tiger: base + gap\n",
+ "merge_and_save_with_filter([base_p, gap_p], EXP_DIR_FILTERED, \"exp_4_0.9_inter_tiger_train.json\", min_history_len=2)\n",
+ "\n",
+ "# 1. Exp 4.1 (Standard)\n",
+ "# Semantics: base + gap (Всё кроме валидации и теста)\n",
+ "merge_and_save_with_filter([base_p, gap_p], EXP_DIR_FILTERED, \"exp_4-1_0.9_inter_semantics_train.json\", min_history_len=2)\n",
+ "\n",
+ "# 2. Exp 4.2 (Short Semantics)\n",
+ "# Semantics: base (Короче на пропуск, без gap)\n",
+ "merge_and_save_with_filter([base_p], EXP_DIR_FILTERED, \"exp_4-2_0.8_inter_semantics_train.json\", min_history_len=2)\n",
+ "\n",
+ "# 3. Exp 4.3 (Leak)\n",
+ "# Semantics: base + gap + valid (Видит валидацию)\n",
+ "merge_and_save_with_filter([base_p, gap_p, valid_p], EXP_DIR_FILTERED, \"exp_4-3_0.95_inter_semantics_train.json\", min_history_len=2)\n",
+ "\n",
+ "# 4. Test Set (тест всех моделей)\n",
+ "merge_and_save([test_p], EXP_DIR_FILTERED, \"test_set.json\")\n",
+ "\n",
+ "# 5. Valid Set (валидационный набор)\n",
+ "merge_and_save([valid_p], EXP_DIR_FILTERED, \"valid_set.json\")\n",
+ "\n",
+ "# 6. All Set (все данные)\n",
+ "merge_and_save([base_p, gap_p, valid_p, test_p], EXP_DIR_FILTERED, \"all_set.json\")\n",
+ "\n",
+ "print(\"All done!\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9060beaa",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Проверка Train сетов (должны быть префиксами):\n",
+ "доля событий всего 0.90:\n",
+ "✅ [ПРЕФИКСЫ] Все 4001 массивов ОК. Полных совпадений: 564\n",
+ "доля событий всего 0.80:\n",
+ "✅ [ПРЕФИКСЫ] Все 3791 массивов ОК. Полных совпадений: 343\n",
+ "доля событий всего 0.95:\n",
+ "✅ [ПРЕФИКСЫ] Все 4111 массивов ОК. Полных совпадений: 984\n",
+ "доля событий всего 0.90:\n",
+ "✅ [ПРЕФИКСЫ] Все 4001 массивов ОК. Полных совпадений: 564\n",
+ "\n",
+ "Проверка Test сета (должен быть суффиксом):\n",
+ "доля событий всего 0.05:\n",
+ "✅ [СУФФИКСЫ] Все 3154 массивов ОК. Полных совпадений: 20\n",
+ "\n",
+ "(Контроль) Проверка Test сета как префикса (должна упасть):\n",
+ "доля событий всего 0.05:\n",
+ "❌ [ПРЕФИКСЫ] Найдено 3134 ошибок.\n",
+ "доля событий всего 1.00:\n",
+ "✅ [ПРЕФИКСЫ] Все 4138 массивов ОК. Полных совпадений: 4138\n",
+ "\n",
+ "================================================================================\n",
+ "ПРОВЕРКА НА ПУСТЫЕ ЧАСТИ ИСТОРИЙ\n",
+ "================================================================================\n",
+ "\n",
+ "[exp_4-1_0.9] Анализ...\n",
+ " Юзеров в сплите: 4,001 / 4,138\n",
+ " ПУСТЫХ сессий: 0\n",
+ " ОБЩИХ ПРОБЛЕМ: 0\n",
+ "\n",
+ "[exp_4-2_0.8] Анализ...\n",
+ " Юзеров в сплите: 3,791 / 4,138\n",
+ " ПУСТЫХ сессий: 0\n",
+ " ОБЩИХ ПРОБЛЕМ: 0\n",
+ "\n",
+ "[exp_4-3_0.95] Анализ...\n",
+ " Юзеров в сплите: 4,111 / 4,138\n",
+ " ПУСТЫХ сессий: 0\n",
+ " ОБЩИХ ПРОБЛЕМ: 0\n",
+ "\n",
+ "[exp_4_0.9_tiger] Анализ...\n",
+ " Юзеров в сплите: 4,001 / 4,138\n",
+ " ПУСТЫХ сессий: 0\n",
+ " ОБЩИХ ПРОБЛЕМ: 0\n",
+ "\n",
+ "ВСЕ РАЗБИЕНИЯ БЕЗ ПУСТЫХ СЕССИЙ\n"
+ ]
+ }
+ ],
+ "source": [
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4-1_0.9_inter_semantics_train.json\", 'r') as ff:\n",
+ " filtered_first_sem = json.load(ff)\n",
+ " \n",
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4-2_0.8_inter_semantics_train.json\", 'r') as ff:\n",
+ " filtered_second_sem = json.load(ff)\n",
+ " \n",
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4-3_0.95_inter_semantics_train.json\", 'r') as ff:\n",
+ " filtered_third_sem = json.load(ff)\n",
+ " \n",
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/exp_4_0.9_inter_tiger_train.json\", 'r') as ff:\n",
+ " filtered_tiger_sem = json.load(ff)\n",
+ "\n",
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/valid_set.json\", 'r') as ff:\n",
+ " fiiltered_valid_sem = json.load(ff)\n",
+ "\n",
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/test_set.json\", 'r') as ff:\n",
+ " fiiltered_test_sem = json.load(ff)\n",
+ "\n",
+ "with open(\"/home/jovyan/IRec/data/Yambda/updated_quantile_splits/merged_for_exps_filtered/all_set.json\", 'r') as ff:\n",
+ " filtered_all_test_data = json.load(ff)\n",
+ "\n",
+ "# --- Запуск проверок ---\n",
+ "print(\"Проверка Train сетов (должны быть префиксами):\")\n",
+ "check_prefix_match(filtered_all_test_data, filtered_first_sem)\n",
+ "check_prefix_match(filtered_all_test_data, filtered_second_sem)\n",
+ "check_prefix_match(filtered_all_test_data, filtered_third_sem)\n",
+ "check_prefix_match(filtered_all_test_data, filtered_tiger_sem)\n",
+ "\n",
+ "print(\"\\nПроверка Test сета (должен быть суффиксом):\")\n",
+ "check_prefix_match(filtered_all_test_data, test_sem, check_suffix=True)\n",
+ "\n",
+ "print(\"\\n(Контроль) Проверка Test сета как префикса (должна упасть):\")\n",
+ "check_prefix_match(filtered_all_test_data, test_sem, check_suffix=False)\n",
+ "\n",
+ "check_prefix_match(filtered_all_test_data, all_test_data)\n",
+ "\n",
+ "split_names = ['exp_4-1_0.9', 'exp_4-2_0.8', 'exp_4-3_0.95', 'exp_4_0.9_tiger']\n",
+ "splits_list_filtered = [filtered_first_sem, filtered_second_sem, filtered_third_sem, filtered_tiger_sem]\n",
+ "\n",
+ "check_non_empty_splits(filtered_all_test_data, splits_list_filtered, split_names, min_history_len = 2)\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "c540c8d5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "для теста и валидации (может упасть и скорее всего упадет)\n",
+ "\n",
+ "================================================================================\n",
+ "ПРОВЕРКА НА ПУСТЫЕ ЧАСТИ ИСТОРИЙ\n",
+ "================================================================================\n",
+ "\n",
+ "[valid] Анализ...\n",
+ " Юзеров в сплите: 3,120 / 4,138\n",
+ " ПУСТЫХ сессий: 88\n",
+ " ОБЩИХ ПРОБЛЕМ: 88\n",
+ "\n",
+ "[test] Анализ...\n",
+ " Юзеров в сплите: 3,154 / 4,138\n",
+ " ПУСТЫХ сессий: 105\n",
+ " ОБЩИХ ПРОБЛЕМ: 105\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Short-history check for the hold-out splits; per the message below, issues are expected here.\n",
+ "print(\"для теста и валидации (может упасть и скорее всего упадет)\")\n",
+ "vt_split_names = ['valid', 'test']\n",
+ "# Pair the filtered valid split with the filtered test split loaded by the same cell\n",
+ "# (`fiiltered_test_sem`) instead of the unrelated `test_sem` global.\n",
+ "vt_splits_list_filtered = [fiiltered_valid_sem, fiiltered_test_sem]\n",
+ "\n",
+ "check_non_empty_splits(filtered_all_test_data, vt_splits_list_filtered, vt_split_names, min_history_len=2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "89efa96e",
+ "metadata": {},
+ "source": [
+ "# Разбиение YAMBDA по дням"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "28e4ddc8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Cutoffs: [25740785, 25827185, 25913585]\n",
+ "✓ Сохранено: /home/jovyan/IRec/data/Yambda/day-splits/raw/inter_new_[0_25740785).json\n",
+ "✓ Сохранено: /home/jovyan/IRec/data/Yambda/day-splits/raw/inter_new_[25740785_25827185).json\n",
+ "✓ Сохранено: /home/jovyan/IRec/data/Yambda/day-splits/raw/inter_new_[25827185_25913585).json\n",
+ "✓ Сохранено: /home/jovyan/IRec/data/Yambda/day-splits/raw/inter_new_[25913585_inf).json\n",
+ "Part 0 [Base]: 4133 users\n",
+ "Part 1 [day -3]: 1381 users\n",
+ "Part 2 [day -2]: 1350 users\n",
+ "Part 3 [day -1]: 1403 users\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Split the histories into four consecutive time windows counted back from the latest event:\n",
+ "# Part 0 (base), then three one-day windows labelled \"day -3\", \"day -2\" and \"day -1\".\n",
+ "\n",
+ "# Latest timestamp over all exploded `timestamps` lists in `df`.\n",
+ "global_max_time = df.select(\n",
+ "    pl.col(\"timestamps\").explode().max()\n",
+ ").item()\n",
+ "\n",
+ "# Window size: `days_val` days expressed in seconds (one day here, not a week).\n",
+ "# NOTE(review): assumes `timestamps` are measured in seconds - confirm against the dataset.\n",
+ "days_val = 1\n",
+ "window_sec = days_val * 24 * 3600 \n",
+ "\n",
+ "# Three cutoffs counted back from the end (T = global_max_time)\n",
+ "cutoff_test_start = global_max_time - window_sec # T - 1 window\n",
+ "cutoff_val_start = global_max_time - 2 * window_sec # T - 2 windows\n",
+ "cutoff_gap_start = global_max_time - 3 * window_sec # T - 3 windows\n",
+ "\n",
+ "# Cutoffs in ascending order, as passed to the splitter.\n",
+ "cutoffs = [\n",
+ "    int(cutoff_gap_start), # Boundary Part 0 | Part 1\n",
+ "    int(cutoff_val_start), # Boundary Part 1 | Part 2\n",
+ "    int(cutoff_test_start) # Boundary Part 2 | Part 3\n",
+ "]\n",
+ "\n",
+ "print(f\"Cutoffs: {cutoffs}\")\n",
+ "\n",
+ "# `split_session_by_timestamps` is defined elsewhere in the notebook; its result is iterated\n",
+ "# below as one sized container per part.\n",
+ "split_files = split_session_by_timestamps(\n",
+ "    df, \n",
+ "    cutoffs, \n",
+ "    output_dir=\"/home/jovyan/IRec/data/Yambda/day-splits/raw\"\n",
+ ")\n",
+ "\n",
+ "# Labels for the parts, in the same order as `split_files`; the per-part statistics cell reuses `names`.\n",
+ "names = [\"Base\", \"day -3\", \"day -2\", \"day -1\"]\n",
+ "for i, d in enumerate(split_files):\n",
+ "    print(f\"Part {i} [{names[i]}]: {len(d)} users\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "8d5b0c22",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Merging 2 files into exp_4_0.9_inter_tiger_train.json (min len=2)...\n",
+ "Filtered 3 users with history < 2\n",
+ "Remaining: 4133 users\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/exp_4_0.9_inter_tiger_train.json (Users: 4133)\n",
+ "Merging 2 files into exp_4-1_0.9_inter_semantics_train.json (min len=2)...\n",
+ "Filtered 3 users with history < 2\n",
+ "Remaining: 4133 users\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/exp_4-1_0.9_inter_semantics_train.json (Users: 4133)\n",
+ "Merging 1 files into exp_4-2_0.8_inter_semantics_train.json (min len=2)...\n",
+ "Filtered 3 users with history < 2\n",
+ "Remaining: 4130 users\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/exp_4-2_0.8_inter_semantics_train.json (Users: 4130)\n",
+ "Merging 3 files into exp_4-3_0.95_inter_semantics_train.json (min len=2)...\n",
+ "Filtered 3 users with history < 2\n",
+ "Remaining: 4133 users\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/exp_4-3_0.95_inter_semantics_train.json (Users: 4133)\n",
+ "Merging 1 files into test_set.json (min len=1)...\n",
+ "Filtered 0 users with history < 1\n",
+ "Remaining: 1403 users\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/test_set.json (Users: 1403)\n",
+ "Merging 1 files into valid_set.json (min len=1)...\n",
+ "Filtered 0 users with history < 1\n",
+ "Remaining: 1350 users\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/valid_set.json (Users: 1350)\n",
+ "Merging 4 files into all_set.json...\n",
+ "✓ Done: /home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered/all_set.json (Users: 4138)\n",
+ "All done!\n"
+ ]
+ }
+ ],
+ "source": [
+ "EXP_DIR_FILTERED = \"/home/jovyan/IRec/data/Yambda/day-splits/merged_for_exps_filtered\"\n",
+ "\n",
+ "# Name the four consecutive time windows produced by the day split.\n",
+ "base_p, gap_p, valid_p, test_p = (split_files[_part_idx] for _part_idx in range(4))\n",
+ "\n",
+ "# (parts to merge, output file name, minimal history length), listed in write order.\n",
+ "_filtered_outputs = [\n",
+ "    # Tiger: base + gap\n",
+ "    ([base_p, gap_p], \"exp_4_0.9_inter_tiger_train.json\", 2),\n",
+ "    # Exp 4.1 (standard) semantics: base + gap, i.e. everything except validation and test\n",
+ "    ([base_p, gap_p], \"exp_4-1_0.9_inter_semantics_train.json\", 2),\n",
+ "    # Exp 4.2 (short semantics): base only, i.e. shorter by the gap window\n",
+ "    ([base_p], \"exp_4-2_0.8_inter_semantics_train.json\", 2),\n",
+ "    # Exp 4.3 (leak) semantics: base + gap + valid, so the validation window is visible\n",
+ "    ([base_p, gap_p, valid_p], \"exp_4-3_0.95_inter_semantics_train.json\", 2),\n",
+ "    # Test set (used to test all models)\n",
+ "    ([test_p], \"test_set.json\", 1),\n",
+ "    # Validation set\n",
+ "    ([valid_p], \"valid_set.json\", 1),\n",
+ "]\n",
+ "for _merge_parts, _out_name, _min_len in _filtered_outputs:\n",
+ "    merge_and_save_with_filter(_merge_parts, EXP_DIR_FILTERED, _out_name, min_history_len=_min_len)\n",
+ "\n",
+ "# All data: the four parts merged with `merge_and_save`, which takes no min-length argument.\n",
+ "merge_and_save([base_p, gap_p, valid_p, test_p], EXP_DIR_FILTERED, \"all_set.json\")\n",
+ "\n",
+ "print(\"All done!\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "c0b9b767",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def check_non_empty_splits(full_data, splits_data, split_names, min_history_len=2):\n",
+ "    \"\"\"\n",
+ "    Report, for every split, how many user histories are empty or too short.\n",
+ "\n",
+ "    Args:\n",
+ "        full_data: mapping user -> history; only its number of users is used (as the denominator).\n",
+ "        splits_data: sequence of mappings user -> history, aligned with ``split_names``.\n",
+ "        split_names: labels printed for each split.\n",
+ "        min_history_len: histories shorter than this are counted as problems.\n",
+ "\n",
+ "    Prints a per-split summary and returns None.\n",
+ "    \"\"\"\n",
+ "    banner = \"=\" * 80\n",
+ "    print(\"\\n\" + banner)\n",
+ "    print(\"ПРОВЕРКА НА ПУСТЫЕ ЧАСТИ ИСТОРИЙ\")\n",
+ "    print(banner)\n",
+ "\n",
+ "    total_user_count = len(set(full_data.keys()))\n",
+ "    problems_overall = 0\n",
+ "\n",
+ "    for idx, split_name in enumerate(split_names):\n",
+ "        split_data = splits_data[idx]\n",
+ "        print(f\"\\n[{split_name}] Анализ...\")\n",
+ "\n",
+ "        split_user_count = len(set(split_data.keys()))\n",
+ "        # Users whose history in this split is missing, empty or below the threshold.\n",
+ "        too_short_users = [\n",
+ "            user\n",
+ "            for user, items in split_data.items()\n",
+ "            if not items or len(items) < min_history_len\n",
+ "        ]\n",
+ "        split_problem_count = len(too_short_users)\n",
+ "        problems_overall += split_problem_count\n",
+ "\n",
+ "        print(f\" Юзеров в сплите: {split_user_count:,} / {total_user_count:,}\")\n",
+ "        print(f\" ПУСТЫХ сессий: {split_problem_count}\")\n",
+ "        print(f\" ОБЩИХ ПРОБЛЕМ: {split_problem_count}\")\n",
+ "\n",
+ "    if problems_overall == 0:\n",
+ "        print(\"\\nВСЕ РАЗБИЕНИЯ БЕЗ ПУСТЫХ СЕССИЙ\")\n",
+ "\n",
+ "\n",
+ "def check_prefix_match(full_data, subset_data, check_suffix=False):\n",
+ "    \"\"\"\n",
+ "    Check that every user's sequence in ``subset_data`` is a prefix of the same user's\n",
+ "    sequence in ``full_data`` (or a suffix when ``check_suffix=True``, used for the test set).\n",
+ "\n",
+ "    Args:\n",
+ "        full_data: mapping user -> full item sequence.\n",
+ "        subset_data: mapping user -> item sequence expected to be the head (or tail) of the full one.\n",
+ "        check_suffix: compare against the tail of the full sequence instead of its head.\n",
+ "\n",
+ "    A user missing from ``full_data``, a subset longer than the full sequence, or differing\n",
+ "    items each count as one mismatch. An empty subset is a valid prefix and a valid suffix.\n",
+ "    Prints a one-line summary (plus a warning per missing user) and returns None.\n",
+ "    \"\"\"\n",
+ "    mismatch_count = 0\n",
+ "    full_match_count = 0\n",
+ "\n",
+ "    # Iterate over the subset: full_data may contain users the subset does not.\n",
+ "    for user, sub_items in subset_data.items():\n",
+ "        if user not in full_data:\n",
+ "            print(f\"⚠ Юзер {user} не найден в исходном файле!\")\n",
+ "            mismatch_count += 1\n",
+ "            continue\n",
+ "\n",
+ "        full_items = full_data[user]\n",
+ "        sub_len = len(sub_items)\n",
+ "        full_len = len(full_items)\n",
+ "\n",
+ "        # A subset longer than the full sequence can be neither a prefix nor a suffix.\n",
+ "        if sub_len > full_len:\n",
+ "            mismatch_count += 1\n",
+ "            continue\n",
+ "\n",
+ "        if check_suffix:\n",
+ "            # Explicit start index: `full_items[-sub_len:]` would return the WHOLE list\n",
+ "            # when sub_len == 0 and wrongly flag an empty subset as a mismatch.\n",
+ "            window = full_items[full_len - sub_len:]\n",
+ "        else:\n",
+ "            window = full_items[:sub_len]\n",
+ "\n",
+ "        if window == sub_items:\n",
+ "            # Equal lengths mean the subset is the user's entire history.\n",
+ "            if full_len == sub_len:\n",
+ "                full_match_count += 1\n",
+ "        else:\n",
+ "            mismatch_count += 1\n",
+ "\n",
+ "    mode = \"СУФФИКСЫ\" if check_suffix else \"ПРЕФИКСЫ\"\n",
+ "\n",
+ "    if mismatch_count == 0:\n",
+ "        print(f\"✅ [{mode}] Все {len(subset_data)} массивов ОК. Полных совпадений: {full_match_count}\")\n",
+ "    else:\n",
+ "        print(f\"❌ [{mode}] Найдено {mismatch_count} ошибок.\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "36ac0115",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Проверка Train сетов (должны быть префиксами):\n",
+ "✅ [ПРЕФИКСЫ] Все 4133 массивов ОК. Полных совпадений: 2272\n",
+ "✅ [ПРЕФИКСЫ] Все 4130 массивов ОК. Полных совпадений: 1969\n",
+ "✅ [ПРЕФИКСЫ] Все 4133 массивов ОК. Полных совпадений: 2735\n",
+ "✅ [ПРЕФИКСЫ] Все 4133 массивов ОК. Полных совпадений: 2272\n",
+ "\n",
+ "Проверка Test сета (должен быть суффиксом):\n",
+ "✅ [СУФФИКСЫ] Все 1403 массивов ОК. Полных совпадений: 2\n",
+ "\n",
+ "(Контроль) Проверка Test сета как префикса (должна упасть):\n",
+ "❌ [ПРЕФИКСЫ] Найдено 1401 ошибок.\n",
+ "✅ [ПРЕФИКСЫ] Все 4138 массивов ОК. Полных совпадений: 4138\n"
+ ]
+ }
+ ],
+ "source": [
+ "def _load_filtered_split(file_name):\n",
+ "    \"\"\"Load one JSON artifact from EXP_DIR_FILTERED.\"\"\"\n",
+ "    with open(f\"{EXP_DIR_FILTERED}/{file_name}\", 'r') as split_file:\n",
+ "        return json.load(split_file)\n",
+ "\n",
+ "\n",
+ "# Train splits of the four experiments.\n",
+ "filtered_first_sem = _load_filtered_split(\"exp_4-1_0.9_inter_semantics_train.json\")\n",
+ "filtered_second_sem = _load_filtered_split(\"exp_4-2_0.8_inter_semantics_train.json\")\n",
+ "filtered_third_sem = _load_filtered_split(\"exp_4-3_0.95_inter_semantics_train.json\")\n",
+ "filtered_tiger_sem = _load_filtered_split(\"exp_4_0.9_inter_tiger_train.json\")\n",
+ "\n",
+ "# Hold-out splits and the full data (global names kept exactly as other cells expect them).\n",
+ "fiiltered_valid_sem = _load_filtered_split(\"valid_set.json\")\n",
+ "filtered_test_sem = _load_filtered_split(\"test_set.json\")\n",
+ "filtered_all_test_data = _load_filtered_split(\"all_set.json\")\n",
+ "\n",
+ "# --- Run the checks ---\n",
+ "print(\"Проверка Train сетов (должны быть префиксами):\")\n",
+ "for _train_split in (filtered_first_sem, filtered_second_sem, filtered_third_sem, filtered_tiger_sem):\n",
+ "    check_prefix_match(filtered_all_test_data, _train_split)\n",
+ "\n",
+ "print(\"\\nПроверка Test сета (должен быть суффиксом):\")\n",
+ "check_prefix_match(filtered_all_test_data, filtered_test_sem, check_suffix=True)\n",
+ "\n",
+ "# Negative control: the test split checked as a prefix is expected to report mismatches.\n",
+ "print(\"\\n(Контроль) Проверка Test сета как префикса (должна упасть):\")\n",
+ "check_prefix_match(filtered_all_test_data, filtered_test_sem, check_suffix=False)\n",
+ "\n",
+ "# Trivial control: the full data compared with itself.\n",
+ "check_prefix_match(filtered_all_test_data, filtered_all_test_data)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "2c65331b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "================================================================================\n",
+ "ПРОВЕРКА НА ПУСТЫЕ ЧАСТИ ИСТОРИЙ\n",
+ "================================================================================\n",
+ "\n",
+ "[exp_4-1_0.9] Анализ...\n",
+ " Юзеров в сплите: 4,133 / 4,138\n",
+ " ПУСТЫХ сессий: 0\n",
+ " ОБЩИХ ПРОБЛЕМ: 0\n",
+ "\n",
+ "[exp_4-2_0.8] Анализ...\n",
+ " Юзеров в сплите: 4,130 / 4,138\n",
+ " ПУСТЫХ сессий: 0\n",
+ " ОБЩИХ ПРОБЛЕМ: 0\n",
+ "\n",
+ "[exp_4-3_0.95] Анализ...\n",
+ " Юзеров в сплите: 4,133 / 4,138\n",
+ " ПУСТЫХ сессий: 0\n",
+ " ОБЩИХ ПРОБЛЕМ: 0\n",
+ "\n",
+ "[exp_4_0.9_tiger] Анализ...\n",
+ " Юзеров в сплите: 4,133 / 4,138\n",
+ " ПУСТЫХ сессий: 0\n",
+ " ОБЩИХ ПРОБЛЕМ: 0\n",
+ "\n",
+ "ВСЕ РАЗБИЕНИЯ БЕЗ ПУСТЫХ СЕССИЙ\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Train splits keyed by the label printed in the report (insertion order is the report order).\n",
+ "_train_splits_by_name = {\n",
+ "    'exp_4-1_0.9': filtered_first_sem,\n",
+ "    'exp_4-2_0.8': filtered_second_sem,\n",
+ "    'exp_4-3_0.95': filtered_third_sem,\n",
+ "    'exp_4_0.9_tiger': filtered_tiger_sem,\n",
+ "}\n",
+ "split_names = list(_train_splits_by_name)\n",
+ "splits_list_filtered = list(_train_splits_by_name.values())\n",
+ "\n",
+ "check_non_empty_splits(filtered_all_test_data, splits_list_filtered, split_names, min_history_len=2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "f596db64",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "0 Base 4133 7264231 \n",
+ "1 day -3 1381 36676 \n",
+ "2 day -2 1350 35128 \n",
+ "3 day -1 1403 35955 \n"
+ ]
+ }
+ ],
+ "source": [
+ "# Per-part summary columns: part index, label, number of users, number of events\n",
+ "# (sum of the per-user sequence lengths). `names` and `split_files` come from the day-split cell.\n",
+ "for i, d in enumerate(split_files):\n",
+ "    num_users = len(d)\n",
+ "    num_events = sum(len(items) for items in d.values())\n",
+ "\n",
+ "    print(f\"{i:<10} {names[i]:<10} {num_users:<10} {num_events:<10}\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/src/irec/callbacks/stopping.py b/src/irec/callbacks/stopping.py
index 3d1405f..bbe091f 100644
--- a/src/irec/callbacks/stopping.py
+++ b/src/irec/callbacks/stopping.py
@@ -44,14 +44,18 @@ def after_step(self, runner: Runner, context: RunnerContext):
metric = context.metrics[self._metric]
if self._best_metric is None:
self._best_metric = metric
- torch.save(runner.model.state_dict(), f'{self._model_path}_best_{round(self._best_metric, 4)}.pth')
+ save_path = f'{self._model_path}_best_{round(self._best_metric, 4)}.pth'
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ torch.save(runner.model.state_dict(), save_path)
else:
if (self._minimize and metric < self._best_metric) or (not self._minimize and metric > self._best_metric):
self._wait = 0
old_metric = self._best_metric
self._best_metric = metric
# Saving new model
- torch.save(runner.model.state_dict(), f'{self._model_path}_best_{round(self._best_metric, 4)}.pth')
+ save_path = f'{self._model_path}_best_{round(self._best_metric, 4)}.pth'
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ torch.save(runner.model.state_dict(), save_path)
# Deleting old model
if str(round(self._best_metric, 4)) != str(round(old_metric, 4)):
os.remove(f'{self._model_path}_best_{round(old_metric, 4)}.pth')