diff --git a/scripts/plum/cooc_data.py b/scripts/plum/cooc_data.py
index b11e6f0..50f2bdd 100644
--- a/scripts/plum/cooc_data.py
+++ b/scripts/plum/cooc_data.py
@@ -13,21 +13,15 @@ class CoocMappingDataset:
def __init__(
self,
train_sampler,
- validation_sampler,
- test_sampler,
num_items,
- max_sequence_length,
cooccur_counter_mapping=None
):
self._train_sampler = train_sampler
- self._validation_sampler = validation_sampler
- self._test_sampler = test_sampler
self._num_items = num_items
- self._max_sequence_length = max_sequence_length
self._cooccur_counter_mapping = cooccur_counter_mapping
@classmethod
- def create(cls, inter_json_path, max_sequence_length, sampler_type, window_size):
+ def create(cls, inter_json_path, window_size):
max_item_id = 0
train_dataset, validation_dataset, test_dataset = [], [], []
@@ -43,31 +37,59 @@ def create(cls, inter_json_path, max_sequence_length, sampler_type, window_size)
'user.ids': [user_id],
'item.ids': item_ids[:-2],
})
- validation_dataset.append({
- 'user.ids': [user_id],
- 'item.ids': item_ids[:-1],
- })
- test_dataset.append({
- 'user.ids': [user_id],
- 'item.ids': item_ids,
- })
cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size)
logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items but max_item_id is {max_item_id}')
train_sampler = train_dataset
- validation_sampler = validation_dataset
- test_sampler = test_dataset
return cls(
train_sampler=train_sampler,
- validation_sampler=validation_sampler,
- test_sampler=test_sampler,
num_items=max_item_id + 1,
- max_sequence_length=max_sequence_length,
cooccur_counter_mapping=cooccur_counter_mapping
)
+ @classmethod
+ def create_from_split_part(
+ cls,
+ train_inter_json_path,
+ window_size
+ ):
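+        # Builds the co-occurrence mapping from an already timestamp-split train JSON
+        # of the form {user_id: [item_ids]}; no validation/test sequences are derived here.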
+
+ max_item_id = 0
+ train_dataset = []
+
+ with open(train_inter_json_path, 'r') as f:
+ train_interactions = json.load(f)
+
+        # Process the TRAIN interactions
+ for user_id_str, item_ids in train_interactions.items():
+ user_id = int(user_id_str)
+ if item_ids:
+ max_item_id = max(max_item_id, max(item_ids))
+
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': item_ids,
+ })
+
+ logger.debug(f'Train: {len(train_dataset)} users')
+ logger.debug(f'Max item ID: {max_item_id}')
+
+ cooccur_counter_mapping = cls.build_cooccur_counter_mapping(
+ train_dataset,
+ window_size=window_size
+ )
+
+ logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items')
+
+ return cls(
+ train_sampler=train_dataset,
+ num_items=max_item_id + 1,
+ cooccur_counter_mapping=cooccur_counter_mapping
+ )
+
+
@staticmethod
def build_cooccur_counter_mapping(train_dataset, window_size): # TODO: pass timestamps and build the window from them
cooccur_counts = defaultdict(Counter)
@@ -80,16 +102,6 @@ def build_cooccur_counter_mapping(train_dataset, window_size): #TODO перед
cooccur_counts[item_i][items[j]] += 1
return cooccur_counts
- def get_datasets(self):
- return self._train_sampler, self._validation_sampler, self._test_sampler
-
- @property
- def num_items(self):
- return self._num_items
-
- @property
- def max_sequence_length(self):
- return self._max_sequence_length
@property
def cooccur_counter_mapping(self):
diff --git a/scripts/plum/infer_default.py b/scripts/plum/infer_default.py
index af8df34..b15fb6d 100644
--- a/scripts/plum/infer_default.py
+++ b/scripts/plum/infer_default.py
@@ -12,8 +12,18 @@
from data import EmbeddingDataset, ProcessEmbeddings
from models import PlumRQVAE
-from transforms import AddWeightedCooccurrenceEmbeddings
-from cooc_data import CoocMappingDataset
+
+# PATHS
+IREC_PATH = '/home/jovyan/IRec/'
+EMBEDDINGS_PATH = '/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl'
+MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-1_plum_rqvae_beauty_ws_2_best_0.0051.pth'
+RESULTS_PATH = os.path.join(IREC_PATH, 'results')
+
+WINDOW_SIZE = 2
+
+EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}'
+
+# OTHER SETTINGS
SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
@@ -26,29 +36,16 @@
NUM_CODEBOOKS = 3
BETA = 0.25
-MODEL_PATH = '/home/jovyan/IRec/checkpoints/test_plum_rqvae_beauty_ws_2_best_0.0054.pth'
-WINDOW_SIZE = 2
-
-EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}'
-
-IREC_PATH = '/home/jovyan/IRec/'
def main():
fix_random_seed(SEED_VALUE)
- data = CoocMappingDataset.create(
- inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter_new.json'),
- max_sequence_length=20,
- sampler_type='sasrec',
- window_size=WINDOW_SIZE
- )
-
dataset = EmbeddingDataset(
- data_path='/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl'
+ data_path=EMBEDDINGS_PATH
)
-
+
item_id_to_embedding = {}
all_item_ids = []
for idx in range(len(dataset)):
@@ -57,15 +54,12 @@ def main():
item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
all_item_ids.append(item_id)
- add_cooc_transform = AddWeightedCooccurrenceEmbeddings(
- data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids)
-
dataloader = DataLoader(
dataset,
batch_size=BATCH_SIZE,
shuffle=False,
drop_last=False,
- ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform)
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))
model = PlumRQVAE(
input_dim=INPUT_DIM,
@@ -106,8 +100,8 @@ def main():
cb.Logger().every_num_steps(len(dataloader)),
cb.InferenceSaver(
- metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
- save_path=f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json',
+ metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
+ save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'),
format='json'
)
]
@@ -125,9 +119,9 @@ def main():
from collections import defaultdict
import numpy as np
- with open(f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json', 'r') as f:
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f:
mappings = json.load(f)
-
+
inter = {}
sem_2_ids = defaultdict(list)
for mapping in mappings:
@@ -143,8 +137,8 @@ def main():
inter[item_id].append(collision_solver)
for i in range(len(inter[item_id])):
inter[item_id][i] += CODEBOOK_SIZE * i
-
- with open(os.path.join(IREC_PATH, 'results', f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
+
+ with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
json.dump(inter, f, indent=2)
diff --git a/scripts/plum/train_plum.py b/scripts/plum/train_plum.py
index 5a00bc3..ffa9e43 100644
--- a/scripts/plum/train_plum.py
+++ b/scripts/plum/train_plum.py
@@ -41,8 +41,6 @@ def main():
data = CoocMappingDataset.create(
inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter_new.json'),
- max_sequence_length=20,
- sampler_type='sasrec',
window_size=WINDOW_SIZE
)
diff --git a/scripts/plum/train_plum_timestamp_based.py b/scripts/plum/train_plum_timestamp_based.py
new file mode 100644
index 0000000..e755d95
--- /dev/null
+++ b/scripts/plum/train_plum_timestamp_based.py
@@ -0,0 +1,168 @@
+from loguru import logger
+import os
+
+import torch
+
+import pickle
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDataset, ProcessEmbeddings
+from models import PlumRQVAE
+from transforms import AddWeightedCooccurrenceEmbeddings
+from cooc_data import CoocMappingDataset
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+NUM_EPOCHS = 500
+BATCH_SIZE = 1024
+
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+WINDOW_SIZE = 2
+
+EXPERIMENT_NAME = f'4-1_plum_rqvae_beauty_ws_{WINDOW_SIZE}'
+INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/Beauty_new/splits/exp_data/exp_4.1_inter_semantics_train.json"
+EMBEDDINGS_PATH = "/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl"
+IREC_PATH = '../../'
+
+def main():
+ fix_random_seed(SEED_VALUE)
+
+ data = CoocMappingDataset.create_from_split_part(
+ train_inter_json_path=INTER_TRAIN_PATH,
+ window_size=WINDOW_SIZE
+ )
+
+ dataset = EmbeddingDataset(
+ data_path=EMBEDDINGS_PATH
+ )
+
+ item_id_to_embedding = {}
+ all_item_ids = []
+ for idx in range(len(dataset)):
+ sample = dataset[idx]
+ item_id = int(sample['item_id'])
+ item_id_to_embedding[item_id] = torch.tensor(sample['embedding'])
+ all_item_ids.append(item_id)
+
+ add_cooc_transform = AddWeightedCooccurrenceEmbeddings(
+ data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids)
+
+ train_dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=True,
+ drop_last=True,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(
+ ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])
+ ).map(add_cooc_transform).repeat(NUM_EPOCHS)
+
+ valid_dataloader = DataLoader(
+ dataset,
+ batch_size=BATCH_SIZE,
+ shuffle=False,
+ drop_last=False,
+ ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform)
+
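+    # Intended as one epoch of steps, assuming len(train_dataloader) spans all NUM_EPOCHS repeats.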
+ LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
+
+ model = PlumRQVAE(
+ input_dim=INPUT_DIM,
+ num_codebooks=NUM_CODEBOOKS,
+ codebook_size=CODEBOOK_SIZE,
+ embedding_dim=HIDDEN_DIM,
+ beta=BETA,
+ quant_loss_weight=1.0,
+ contrastive_loss_weight=1.0,
+ temperature=1.0
+ ).to(DEVICE)
+
+ total_params = sum(p.numel() for p in model.parameters())
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+ logger.debug(f'Overall parameters: {total_params:,}')
+ logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True)
+
+ callbacks = [
+ InitCodebooks(valid_dataloader),
+
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='train'),
+
+ FixDeadCentroids(valid_dataloader),
+
+ cb.MetricAccumulator(
+ accumulators={
+ 'train/loss': cb.MeanAccumulator(),
+ 'train/recon_loss': cb.MeanAccumulator(),
+ 'train/rqvae_loss': cb.MeanAccumulator(),
+ 'train/con_loss': cb.MeanAccumulator(),
+ 'num_dead/0': cb.MeanAccumulator(),
+ 'num_dead/1': cb.MeanAccumulator(),
+ 'num_dead/2': cb.MeanAccumulator(),
+ },
+ reset_every_num_steps=LOG_EVERY_NUM_STEPS
+ ),
+
+ cb.Validation(
+ dataset=valid_dataloader,
+ callbacks=[
+ cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+ 'loss': model_outputs['loss'],
+ 'recon_loss': model_outputs['recon_loss'],
+ 'rqvae_loss': model_outputs['rqvae_loss'],
+ 'con_loss': model_outputs['con_loss']
+ }, name='valid'),
+ cb.MetricAccumulator(
+ accumulators={
+ 'valid/loss': cb.MeanAccumulator(),
+ 'valid/recon_loss': cb.MeanAccumulator(),
+ 'valid/rqvae_loss': cb.MeanAccumulator(),
+ 'valid/con_loss': cb.MeanAccumulator()
+ }
+ ),
+ ],
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+
+ cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+
+ cb.EarlyStopping(
+ metric='valid/recon_loss',
+ patience=40,
+ minimize=True,
+ model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
+ ).every_num_steps(LOG_EVERY_NUM_STEPS),
+ ]
+
+ logger.debug('Everything is ready for training process!')
+
+ runner = TrainingRunner(
+ model=model,
+ optimizer=optimizer,
+ dataset=train_dataloader,
+ callbacks=callbacks,
+ )
+ runner.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/scripts/tiger/data.py b/scripts/tiger/data.py
index 188993a..760fe8c 100644
--- a/scripts/tiger/data.py
+++ b/scripts/tiger/data.py
@@ -28,6 +28,113 @@ def __init__(
self._num_items = num_items
self._max_sequence_length = max_sequence_length
+ @classmethod
+ def create_timestamp_based(
+ cls,
+ train_json_path,
+ validation_json_path,
+ test_json_path,
+ max_sequence_length,
+ sampler_type,
+ min_sample_len=2,
+ is_extended=False
+ ):
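+        # Builds train/validation/test samplers from three timestamp-split JSONs of the form
+        # {user_id: [item_ids]}. Each validation/test item becomes its own sample on top of the
+        # user's accumulated history; with is_extended, every train prefix of length >= min_sample_len
+        # also becomes a train sample.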
+ max_item_id = 0
+ train_dataset, validation_dataset, test_dataset = [], [], []
+
+ with open(train_json_path, 'r') as f:
+ train_data = json.load(f)
+ with open(validation_json_path, 'r') as f:
+ validation_data = json.load(f)
+ with open(test_json_path, 'r') as f:
+ test_data = json.load(f)
+
+ all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys())
+
+ for user_id_str in all_users:
+ user_id = int(user_id_str)
+
+ train_items = train_data.get(user_id_str, [])
+ validation_items = validation_data.get(user_id_str, [])
+ test_items = test_data.get(user_id_str, [])
+
+ full_sequence = train_items + validation_items + test_items
+ if full_sequence:
+ max_item_id = max(max_item_id, max(full_sequence))
+
+ assert len(full_sequence) >= 5, f'Core-5 dataset is used, user {user_id} has only {len(full_sequence)} items'
+
+ if is_extended:
+ # sample = [1, 2]
+ # sample = [1, 2, 3]
+ # sample = [1, 2, 3, 4]
+ # sample = [1, 2, 3, 4, 5]
+ # sample = [1, 2, 3, 4, 5, 6]
+ # sample = [1, 2, 3, 4, 5, 6, 7]
+ # sample = [1, 2, 3, 4, 5, 6, 7, 8]
+ for prefix_length in range(min_sample_len, len(train_items) + 1):
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': train_items[:prefix_length],
+ })
+ else:
+ # sample = [1, 2, 3, 4, 5, 6, 7, 8]
+ train_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': train_items,
+ })
+
+            # expand each validation item into its own sample
+            # Example: Train=[1,2], Valid=[3,4]
+ # sample = [1, 2, 3]
+ # sample = [1, 2, 3, 4]
+
+ current_history = train_items.copy()
+ for item in validation_items:
+                # the eval dataset cuts off the target itself later
+ sample_sequence = current_history + [item]
+
+ if len(sample_sequence) >= min_sample_len:
+ validation_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': sample_sequence,
+ })
+ current_history.append(item)
+
+            # expand each test item into its own sample
+            # Example: Train=[1,2], Valid=[3,4], Test=[5, 6]
+ # sample = [1, 2, 3, 4, 5]
+ # sample = [1, 2, 3, 4, 5, 6]
+ current_history = train_items + validation_items
+
+ for item in test_items:
+                # the eval dataset cuts off the target itself later
+ sample_sequence = current_history + [item]
+
+ if len(sample_sequence) >= min_sample_len:
+ test_dataset.append({
+ 'user.ids': [user_id],
+ 'item.ids': sample_sequence,
+ })
+
+ current_history.append(item)
+
+ logger.debug(f'Train dataset size: {len(train_dataset)}')
+ logger.debug(f'Validation dataset size: {len(validation_dataset)}')
+ logger.debug(f'Test dataset size: {len(test_dataset)}')
+
+ train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length)
+ validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length)
+ test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length)
+
+ return cls(
+ train_sampler=train_sampler,
+ validation_sampler=validation_sampler,
+ test_sampler=test_sampler,
+ num_items=max_item_id + 1, # +1 added because our ids are 0-indexed
+ max_sequence_length=max_sequence_length
+ )
+
@classmethod
def create(cls, inter_json_path, max_sequence_length, sampler_type, is_extended=False):
max_item_id = 0
diff --git a/scripts/tiger/train.py b/scripts/tiger/train.py
index f436dd4..1a2d347 100644
--- a/scripts/tiger/train.py
+++ b/scripts/tiger/train.py
@@ -14,10 +14,23 @@
from data import ArrowBatchDataset
from models import TigerModel, CorrectItemsLogitsProcessor
+
+# PATHS
+IREC_PATH = '../../'
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-1_plum_rqvae_beauty_ws_2_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_eval_batches/')
+
+TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs')
+CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints')
+
+EXPERIMENT_NAME = 'tiger_beauty_4-1_plum_ws_2_dp_0.2'
+
+# OTHER SETTINGS
SEED_VALUE = 42
DEVICE = 'cuda'
-EXPERIMENT_NAME = 'tiger_beauty'
NUM_EPOCHS = 300
MAX_SEQ_LEN = 20
TRAIN_BATCH_SIZE = 256
@@ -30,13 +43,12 @@
NUM_LAYERS = 4
FEEDFORWARD_DIM = 1024
KV_DIM = 64
-DROPOUT = 0.1
+DROPOUT = 0.2
NUM_BEAMS = 30
TOP_K = 20
NUM_CODEBOOKS = 4
-LR = 3e-4
+LR = 0.0001
-IREC_PATH = '../../'
torch.set_float32_matmul_precision('high')
torch._dynamo.config.capture_scalar_outputs = True
@@ -48,30 +60,30 @@
def main():
fix_random_seed(SEED_VALUE)
- with open(os.path.join(IREC_PATH, 'results/rqvae_beauty_best_clusters_colisionless.json'), 'r') as f:
+ with open(SEMANTIC_MAPPING_PATH, 'r') as f:
mappings = json.load(f)
-
+
train_dataloader = DataLoader(
ArrowBatchDataset(
- os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/'),
- device='cpu',
+ TRAIN_BATCHES_DIR,
+ device='cpu',
preload=True
),
- batch_size=1,
- shuffle=True,
+ batch_size=1,
+ shuffle=True,
num_workers=0,
- pin_memory=True,
+ pin_memory=True,
collate_fn=Collate()
).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS)
valid_dataloder = ArrowBatchDataset(
- os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/'),
+ VALID_BATCHES_DIR,
device=DEVICE,
preload=True
)
eval_dataloder = ArrowBatchDataset(
- os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/'),
+ EVAL_BATCHES_DIR,
device=DEVICE,
preload=True
)
@@ -177,22 +189,22 @@ def main():
),
],
).every_num_steps(EPOCH_NUM_STEPS),
-
+
cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
- cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+ cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR),
cb.EarlyStopping(
- metric='eval/ndcg@20',
+ metric='eval/ndcg@20',
patience=40,
minimize=False,
- model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
+ model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME)
).every_num_steps(EPOCH_NUM_STEPS)
# cb.Profiler(
# wait=10,
# warmup=10,
# active=10,
- # logdir=os.path.join(IREC_PATH, 'tensorboard_logs')
+ # logdir=TENSORBOARD_LOGDIR
# ),
# cb.StopAfterNumSteps(40)
diff --git a/scripts/tiger/varka.py b/scripts/tiger/varka.py
index ed47595..4dc3e02 100644
--- a/scripts/tiger/varka.py
+++ b/scripts/tiger/varka.py
@@ -15,6 +15,20 @@
from data import Dataset
+
+
+# PATHS
+
+IREC_PATH = '../../'
+INTERACTIONS_PATH = os.path.join(IREC_PATH, 'data/Beauty/inter.json')
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/rqvae_beauty_best_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/')
+
+
+# OTHER SETTINGS
+
SEED_VALUE = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
@@ -32,8 +46,6 @@
DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3,
-IREC_PATH = '../../'
-
class TigerProcessing(Transform):
def __call__(self, batch):
@@ -42,12 +54,12 @@ def __call__(self, batch):
input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ???
- input_semantic_ids = np.concat([
+ input_semantic_ids = np.concatenate([
input_semantic_ids,
NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
], axis=-1)
- attention_mask = np.concat([
+ attention_mask = np.concatenate([
attention_mask,
np.ones((batch_size, 1), dtype=attention_mask.dtype)
], axis=-1)
@@ -56,7 +68,7 @@ def __call__(self, batch):
batch['input.mask'] = attention_mask
target_semantic_ids = batch['labels.semantic.padded']
- target_semantic_ids = np.concat([
+ target_semantic_ids = np.concatenate([
np.ones(
(batch_size, 1),
dtype=np.int64,
@@ -73,7 +85,7 @@ class ToMasked(Transform):
def __init__(self, prefix, is_right_aligned=False):
self._prefix = prefix
self._is_right_aligned = is_right_aligned
-
+
def __call__(self, batch):
data = batch[f'{self._prefix}.ids']
lengths = batch[f'{self._prefix}.length']
@@ -92,7 +104,7 @@ def __call__(self, batch):
(batch_size, max_sequence_length, data.shape[-1]),
dtype=data.dtype
) # (batch_size, max_seq_len, emb_dim)
-
+
mask = np.arange(max_sequence_length)[None] < lengths[:, None]
if self._is_right_aligned:
@@ -117,10 +129,10 @@ def __init__(self, mapping, names=[]):
data.append(mapping[str(i)])
self._mapping_tensor = torch.tensor(data, dtype=torch.long)
self._semantic_length = self._mapping_tensor.shape[-1]
-
+
def __call__(self, batch):
for name in self._names:
- if f'{name}.ids' in batch:
+ if f'{name}.ids' in batch:
ids = batch[f'{name}.ids']
lengths = batch[f'{name}.length']
assert ids.min() >= 0
@@ -135,7 +147,7 @@ class UserHashing(Transform):
def __init__(self, hash_size):
super().__init__()
self._hash_size = hash_size
-
+
def __call__(self, batch):
batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64)
return batch
@@ -144,7 +156,7 @@ def __call__(self, batch):
def save_batches_to_arrow(batches, output_dir):
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=False)
-
+
for batch_idx, batch in enumerate(batches):
length_groups = defaultdict(dict)
metadata_groups = defaultdict(dict)
@@ -164,7 +176,7 @@ def save_batches_to_arrow(batches, output_dir):
else:
# >2D array - flatten and store the shape
length_groups[length][key] = value.flatten()
-
+
for length, fields in length_groups.items():
arrow_dict = {}
for k, v in fields.items():
@@ -173,11 +185,11 @@ def save_batches_to_arrow(batches, output_dir):
arrow_dict[k] = pa.array(v)
else:
arrow_dict[k] = pa.array(v)
-
+
table = pa.table(arrow_dict)
if length in metadata_groups:
table = table.replace_schema_metadata(metadata_groups[length])
-
+
feather.write_feather(
table,
output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
@@ -186,7 +198,7 @@ def save_batches_to_arrow(batches, output_dir):
# arrow_dict = {k: pa.array(v) for k, v in fields.items()}
# table = pa.table(arrow_dict)
-
+
# feather.write_feather(
# table,
# output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
@@ -196,15 +208,15 @@ def save_batches_to_arrow(batches, output_dir):
def main():
data = Dataset.create(
- inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter.json'),
+ inter_json_path=INTERACTIONS_PATH,
max_sequence_length=MAX_SEQ_LEN,
sampler_type='tiger',
is_extended=True
)
- with open(os.path.join(IREC_PATH, 'results/rqvae_beauty_best_clusters_colisionless.json'), 'r') as f:
+ with open(SEMANTIC_MAPPING_PATH, 'r') as f:
mappings = json.load(f)
-
+
train_dataset, valid_dataset, eval_dataset = data.get_datasets()
train_dataloader = DataLoader(
@@ -219,7 +231,7 @@ def main():
.map(ToMasked('item.semantic', is_right_aligned=True)) \
.map(ToMasked('labels.semantic', is_right_aligned=True)) \
.map(TigerProcessing())
-
+
valid_dataloader = DataLoader(
dataset=valid_dataset,
batch_size=VALID_BATCH_SIZE,
@@ -251,17 +263,18 @@ def main():
train_batches = []
for train_batch in train_dataloader:
train_batches.append(train_batch)
- save_batches_to_arrow(train_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/'))
-
+ save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR)
+
valid_batches = []
for valid_batch in valid_dataloader:
valid_batches.append(valid_batch)
- save_batches_to_arrow(valid_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/'))
-
+ save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR)
+
eval_batches = []
for eval_batch in eval_dataloader:
eval_batches.append(eval_batch)
- save_batches_to_arrow(eval_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/'))
+ save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR)
+
if __name__ == '__main__':
diff --git a/scripts/tiger/varka_timestamp_based.py b/scripts/tiger/varka_timestamp_based.py
new file mode 100644
index 0000000..11343ea
--- /dev/null
+++ b/scripts/tiger/varka_timestamp_based.py
@@ -0,0 +1,287 @@
+from collections import defaultdict
+import json
+import murmurhash
+import numpy as np
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.transforms import Collate, Transform
+from irec.data.dataloader import DataLoader
+
+from data import Dataset
+
+
+
+# PATHS
+
+IREC_PATH = '../../'
+INTERACTIONS_TRAIN_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/splits/exp_data/exp_4_inter_tiger_train.json')
+INTERACTIONS_VALID_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/splits/exp_data/valid_skip_set.json')
+INTERACTIONS_TEST_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/splits/exp_data/test_set.json')
+
+SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-1_plum_rqvae_beauty_ws_2_clusters_colisionless.json')
+TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_train_batches/')
+VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_valid_batches/')
+EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_eval_batches/')
+
+
+# OTHER SETTINGS
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+NUM_USER_HASH = 2000
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 4
+
+UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities
+PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1
+EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2
+DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3
+
+
+
+class TigerProcessing(Transform):
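+    # Masks out padded positions, appends the hashed user-id token (offset past the
+    # codebook vocabulary) to the semantic-id inputs, and prepends the decoder start
+    # token to the target semantic ids.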
+ def __call__(self, batch):
+ input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask']
+ batch_size = attention_mask.shape[0]
+
+ input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ???
+
+ input_semantic_ids = np.concatenate([
+ input_semantic_ids,
+ NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
+ ], axis=-1)
+
+ attention_mask = np.concatenate([
+ attention_mask,
+ np.ones((batch_size, 1), dtype=attention_mask.dtype)
+ ], axis=-1)
+
+ batch['input.data'] = input_semantic_ids
+ batch['input.mask'] = attention_mask
+
+ target_semantic_ids = batch['labels.semantic.padded']
+ target_semantic_ids = np.concatenate([
+ np.ones(
+ (batch_size, 1),
+ dtype=np.int64,
+ ) * DECODER_START_TOKEN_ID,
+ target_semantic_ids
+ ], axis=-1)
+
+ batch['output.data'] = target_semantic_ids
+
+ return batch
+
+
+class ToMasked(Transform):
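+    # Pads `<prefix>.ids` (indices or embeddings) to the longest sequence in the batch
+    # and emits a matching boolean mask, optionally right-aligning the padded values.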
+ def __init__(self, prefix, is_right_aligned=False):
+ self._prefix = prefix
+ self._is_right_aligned = is_right_aligned
+
+ def __call__(self, batch):
+ data = batch[f'{self._prefix}.ids']
+ lengths = batch[f'{self._prefix}.length']
+
+ batch_size = lengths.shape[0]
+ max_sequence_length = int(lengths.max())
+
+ if len(data.shape) == 1: # only indices
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len)
+ else:
+ assert len(data.shape) == 2 # embeddings
+ padded_tensor = np.zeros(
+ (batch_size, max_sequence_length, data.shape[-1]),
+ dtype=data.dtype
+ ) # (batch_size, max_seq_len, emb_dim)
+
+ mask = np.arange(max_sequence_length)[None] < lengths[:, None]
+
+ if self._is_right_aligned:
+ mask = np.flip(mask, axis=-1)
+
+ padded_tensor[mask] = data
+
+ batch[f'{self._prefix}.padded'] = padded_tensor
+ batch[f'{self._prefix}.mask'] = mask
+
+ return batch
+
+
+class SemanticIdsMapper(Transform):
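+    # Maps raw item ids to their semantic-id codes from the collision-free cluster mapping
+    # and flattens the codes into one sequence per entry; lengths are scaled by the code length.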
+ def __init__(self, mapping, names=[]):
+ super().__init__()
+ self._mapping = mapping
+ self._names = names
+
+ data = []
+ for i in range(len(mapping)):
+ data.append(mapping[str(i)])
+ self._mapping_tensor = torch.tensor(data, dtype=torch.long)
+ self._semantic_length = self._mapping_tensor.shape[-1]
+
+ def __call__(self, batch):
+ for name in self._names:
+ if f'{name}.ids' in batch:
+ ids = batch[f'{name}.ids']
+ lengths = batch[f'{name}.length']
+ assert ids.min() >= 0
+ assert ids.max() < self._mapping_tensor.shape[0]
+ batch[f'{name}.semantic.ids'] = self._mapping_tensor[ids].flatten().numpy()
+ batch[f'{name}.semantic.length'] = lengths * self._semantic_length
+
+ return batch
+
+
+class UserHashing(Transform):
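+    # Buckets user ids into hash_size buckets via murmurhash.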
+ def __init__(self, hash_size):
+ super().__init__()
+ self._hash_size = hash_size
+
+ def __call__(self, batch):
+ batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64)
+ return batch
+
+
+def save_batches_to_arrow(batches, output_dir):
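+    # Groups each batch's fields by their length, stores shape/dtype metadata,
+    # and writes every group as an lz4-compressed Feather file.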
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ for batch_idx, batch in enumerate(batches):
+ length_groups = defaultdict(dict)
+ metadata_groups = defaultdict(dict)
+
+ for key, value in batch.items():
+ length = len(value)
+
+ metadata_groups[length][f'{key}_shape'] = str(value.shape)
+ metadata_groups[length][f'{key}_dtype'] = str(value.dtype)
+
+ if value.ndim == 1:
+                # 1D array - store as is
+ length_groups[length][key] = value
+ elif value.ndim == 2:
+                # 2D array - store as a list of lists
+ length_groups[length][key] = value.tolist()
+ else:
+                # >2D array - flatten and store the shape
+ length_groups[length][key] = value.flatten()
+
+ for length, fields in length_groups.items():
+ arrow_dict = {}
+ for k, v in fields.items():
+ if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list):
+ # List of lists (2D)
+ arrow_dict[k] = pa.array(v)
+ else:
+ arrow_dict[k] = pa.array(v)
+
+ table = pa.table(arrow_dict)
+ if length in metadata_groups:
+ table = table.replace_schema_metadata(metadata_groups[length])
+
+ feather.write_feather(
+ table,
+ output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+ compression='lz4'
+ )
+
+ # arrow_dict = {k: pa.array(v) for k, v in fields.items()}
+ # table = pa.table(arrow_dict)
+
+ # feather.write_feather(
+ # table,
+ # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow",
+ # compression='lz4'
+ # )
+
+
+def main():
+ data = Dataset.create_timestamp_based(
+ train_json_path=INTERACTIONS_TRAIN_PATH,
+ validation_json_path=INTERACTIONS_VALID_PATH,
+ test_json_path=INTERACTIONS_TEST_PATH,
+ max_sequence_length=MAX_SEQ_LEN,
+ sampler_type='tiger',
+ min_sample_len=2,
+ is_extended=True
+ )
+
+ with open(SEMANTIC_MAPPING_PATH, 'r') as f:
+ mappings = json.load(f)
+
+ train_dataset, valid_dataset, eval_dataset = data.get_datasets()
+
+ train_dataloader = DataLoader(
+ dataset=train_dataset,
+ batch_size=TRAIN_BATCH_SIZE,
+ shuffle=True,
+ drop_last=True
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ valid_dataloader = DataLoader(
+ dataset=valid_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ eval_dataloader = DataLoader(
+ dataset=eval_dataset,
+ batch_size=VALID_BATCH_SIZE,
+ shuffle=False,
+ drop_last=False
+ ) \
+ .map(Collate()) \
+ .map(UserHashing(NUM_USER_HASH)) \
+ .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \
+ .map(ToMasked('item.semantic', is_right_aligned=True)) \
+ .map(ToMasked('labels.semantic', is_right_aligned=True)) \
+ .map(ToMasked('visited', is_right_aligned=True)) \
+ .map(TigerProcessing())
+
+ train_batches = []
+ for train_batch in train_dataloader:
+ train_batches.append(train_batch)
+ save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR)
+
+ valid_batches = []
+ for valid_batch in valid_dataloader:
+ valid_batches.append(valid_batch)
+ save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR)
+
+ eval_batches = []
+ for eval_batch in eval_dataloader:
+ eval_batches.append(eval_batch)
+ save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR)
+
+
+
+if __name__ == '__main__':
+ main()
diff --git a/sigir/DatasetProcessing.ipynb b/sigir/DatasetProcessing.ipynb
new file mode 100644
index 0000000..09b8d21
--- /dev/null
+++ b/sigir/DatasetProcessing.ipynb
@@ -0,0 +1,727 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "3bdb292f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from collections import defaultdict\n",
+ "\n",
+ "import json\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import pickle\n",
+ "import polars as pl\n",
+ "\n",
+ "from transformers import LlamaModel, LlamaTokenizer\n",
+ "\n",
+ "import torch\n",
+ "from torch.utils.data import DataLoader\n",
+ "\n",
+ "from tqdm import tqdm as tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "66d9b312",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "interactions_dataset_path = '../data/Beauty/Beauty_5.json'\n",
+ "metadata_path = '../data/Beauty/metadata.json'\n",
+ "\n",
+ "interactions_output_json_path = '../data/Beauty_new/inter_new.json'\n",
+ "interactions_output_parquet_path = '../data/Beauty_new/inter_new.parquet'\n",
+ "embeddings_output_path = '../data/Beauty_new/content_embeddings.pkl'\n",
+ "item_ids_mapping_output_path = '../data/Beauty_new/item_ids_mapping.json'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "6ed4dffb",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of events: 198502\n"
+ ]
+ }
+ ],
+ "source": [
+ "df = defaultdict(list)\n",
+ "\n",
+ "with open(interactions_dataset_path, 'r') as f:\n",
+ " for line in f.readlines():\n",
+ " review = json.loads(line)\n",
+ " df['user_id'].append(review['reviewerID'])\n",
+ " df['item_id'].append(review['asin'])\n",
+ " df['timestamp'].append(review['unixReviewTime'])\n",
+ "\n",
+ "print(f'Number of events: {len(df[\"user_id\"])}')\n",
+ "\n",
+ "df = pl.from_dict(df)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "c26746c4",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "
shape: (5, 3)| user_id | item_id | timestamp |
|---|
| str | str | i64 |
| "A1YJEY40YUW4SE" | "7806397051" | 1391040000 |
| "A60XNB876KYML" | "7806397051" | 1397779200 |
| "A3G6XNM240RMWA" | "7806397051" | 1378425600 |
| "A1PQFP6SAJ6D80" | "7806397051" | 1386460800 |
| "A38FVHZTNQ271F" | "7806397051" | 1382140800 |
"
+ ],
+ "text/plain": [
+ "shape: (5, 3)\n",
+ "┌────────────────┬────────────┬────────────┐\n",
+ "│ user_id ┆ item_id ┆ timestamp │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ str ┆ str ┆ i64 │\n",
+ "╞════════════════╪════════════╪════════════╡\n",
+ "│ A1YJEY40YUW4SE ┆ 7806397051 ┆ 1391040000 │\n",
+ "│ A60XNB876KYML ┆ 7806397051 ┆ 1397779200 │\n",
+ "│ A3G6XNM240RMWA ┆ 7806397051 ┆ 1378425600 │\n",
+ "│ A1PQFP6SAJ6D80 ┆ 7806397051 ┆ 1386460800 │\n",
+ "│ A38FVHZTNQ271F ┆ 7806397051 ┆ 1382140800 │\n",
+ "└────────────────┴────────────┴────────────┘"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "adcf5713",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "filtered_df = df.clone()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "c0bbf9ba",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Processing dataset to get core-5 state in case full dataset is provided\n",
+ "is_changed = True\n",
+ "threshold = 5\n",
+ "good_users = set()\n",
+ "good_items = set()\n",
+ "\n",
+ "while is_changed:\n",
+ " user_counts = filtered_df.group_by('user_id').agg(\n",
+ " pl.len().alias('user_count'),\n",
+ " )\n",
+ " item_counts = filtered_df.group_by('item_id').agg(\n",
+ " pl.len().alias('item_count'),\n",
+ " )\n",
+ "\n",
+ " good_users = user_counts.filter(pl.col('user_count') >= threshold).select(\n",
+ " 'user_id',\n",
+ " )\n",
+ " good_items = item_counts.filter(pl.col('item_count') >= threshold).select(\n",
+ " 'item_id',\n",
+ " )\n",
+ "\n",
+ " old_size = len(filtered_df)\n",
+ "\n",
+ " new_df = filtered_df.join(good_users, on='user_id', how='inner')\n",
+ " new_df = new_df.join(good_items, on='item_id', how='inner')\n",
+ "\n",
+ " new_size = len(new_df)\n",
+ "\n",
+ " filtered_df = new_df\n",
+ " is_changed = old_size != new_size\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "218a9348",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (5, 3)| user_id | item_id | timestamp |
|---|
| i64 | i64 | i64 |
| 0 | 0 | 1391040000 |
| 1 | 0 | 1397779200 |
| 2 | 0 | 1378425600 |
| 3 | 0 | 1386460800 |
| 4 | 0 | 1382140800 |
"
+ ],
+ "text/plain": [
+ "shape: (5, 3)\n",
+ "┌─────────┬─────────┬────────────┐\n",
+ "│ user_id ┆ item_id ┆ timestamp │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ i64 ┆ i64 ┆ i64 │\n",
+ "╞═════════╪═════════╪════════════╡\n",
+ "│ 0 ┆ 0 ┆ 1391040000 │\n",
+ "│ 1 ┆ 0 ┆ 1397779200 │\n",
+ "│ 2 ┆ 0 ┆ 1378425600 │\n",
+ "│ 3 ┆ 0 ┆ 1386460800 │\n",
+ "│ 4 ┆ 0 ┆ 1382140800 │\n",
+ "└─────────┴─────────┴────────────┘"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "unique_values = filtered_df[\"user_id\"].unique(maintain_order=True).to_list()\n",
+ "user_ids_mapping = {value: i for i, value in enumerate(unique_values)}\n",
+ "\n",
+ "filtered_df = filtered_df.with_columns(\n",
+ " pl.col(\"user_id\").replace_strict(user_ids_mapping)\n",
+ ")\n",
+ "\n",
+ "unique_values = filtered_df[\"item_id\"].unique(maintain_order=True).to_list()\n",
+ "item_ids_mapping = {value: i for i, value in enumerate(unique_values)}\n",
+ "\n",
+ "filtered_df = filtered_df.with_columns(\n",
+ " pl.col(\"item_id\").replace_strict(item_ids_mapping)\n",
+ ")\n",
+ "\n",
+ "filtered_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "34604fe6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (5, 2)| old_item_id | new_item_id |
|---|
| str | i64 |
| "7806397051" | 0 |
| "9759091062" | 1 |
| "9788072216" | 2 |
| "9790790961" | 3 |
| "9790794231" | 4 |
"
+ ],
+ "text/plain": [
+ "shape: (5, 2)\n",
+ "┌─────────────┬─────────────┐\n",
+ "│ old_item_id ┆ new_item_id │\n",
+ "│ --- ┆ --- │\n",
+ "│ str ┆ i64 │\n",
+ "╞═════════════╪═════════════╡\n",
+ "│ 7806397051 ┆ 0 │\n",
+ "│ 9759091062 ┆ 1 │\n",
+ "│ 9788072216 ┆ 2 │\n",
+ "│ 9790790961 ┆ 3 │\n",
+ "│ 9790794231 ┆ 4 │\n",
+ "└─────────────┴─────────────┘"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "item_ids_mapping_df = pl.from_dict({\n",
+ " 'old_item_id': list(item_ids_mapping.keys()),\n",
+ " 'new_item_id': list(item_ids_mapping.values())\n",
+ "})\n",
+ "item_ids_mapping_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "99b54db807b9495c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(item_ids_mapping_output_path, 'w') as f:\n",
+ " json.dump({str(k): v for k, v in item_ids_mapping.items()}, f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "6017e65c",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (5, 3)| user_id | item_id | timestamp |
|---|
| i64 | i64 | i64 |
| 0 | 0 | 1391040000 |
| 1 | 0 | 1397779200 |
| 2 | 0 | 1378425600 |
| 3 | 0 | 1386460800 |
| 4 | 0 | 1382140800 |
"
+ ],
+ "text/plain": [
+ "shape: (5, 3)\n",
+ "┌─────────┬─────────┬────────────┐\n",
+ "│ user_id ┆ item_id ┆ timestamp │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ i64 ┆ i64 ┆ i64 │\n",
+ "╞═════════╪═════════╪════════════╡\n",
+ "│ 0 ┆ 0 ┆ 1391040000 │\n",
+ "│ 1 ┆ 0 ┆ 1397779200 │\n",
+ "│ 2 ┆ 0 ┆ 1378425600 │\n",
+ "│ 3 ┆ 0 ┆ 1386460800 │\n",
+ "│ 4 ┆ 0 ┆ 1382140800 │\n",
+ "└─────────┴─────────┴────────────┘"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "filtered_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "9efd1983",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "filtered_df = filtered_df.sort([\"user_id\", \"timestamp\"])\n",
+ "\n",
+ "grouped_filtered_df = filtered_df.group_by(\"user_id\", maintain_order=True).agg(pl.all())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "fd51c525",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (5, 2)| old_item_id | new_item_id |
|---|
| str | i64 |
| "7806397051" | 0 |
| "9759091062" | 1 |
| "9788072216" | 2 |
| "9790790961" | 3 |
| "9790794231" | 4 |
"
+ ],
+ "text/plain": [
+ "shape: (5, 2)\n",
+ "┌─────────────┬─────────────┐\n",
+ "│ old_item_id ┆ new_item_id │\n",
+ "│ --- ┆ --- │\n",
+ "│ str ┆ i64 │\n",
+ "╞═════════════╪═════════════╡\n",
+ "│ 7806397051 ┆ 0 │\n",
+ "│ 9759091062 ┆ 1 │\n",
+ "│ 9788072216 ┆ 2 │\n",
+ "│ 9790790961 ┆ 3 │\n",
+ "│ 9790794231 ┆ 4 │\n",
+ "└─────────────┴─────────────┘"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "item_ids_mapping_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "8b0821da",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
shape: (5, 3)| user_id | item_id | timestamp |
|---|
| i64 | list[i64] | list[i64] |
| 0 | [6845, 7872, … 0] | [1318896000, 1318896000, … 1391040000] |
| 1 | [815, 10405, … 232] | [1392422400, 1396224000, … 1397779200] |
| 2 | [6049, 0, … 6608] | [1378425600, 1378425600, … 1400284800] |
| 3 | [5521, 5160, … 0] | [1379116800, 1380931200, … 1386460800] |
| 4 | [0, 10469, … 11389] | [1382140800, 1383523200, … 1388966400] |
"
+ ],
+ "text/plain": [
+ "shape: (5, 3)\n",
+ "┌─────────┬─────────────────────┬─────────────────────────────────┐\n",
+ "│ user_id ┆ item_id ┆ timestamp │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ i64 ┆ list[i64] ┆ list[i64] │\n",
+ "╞═════════╪═════════════════════╪═════════════════════════════════╡\n",
+ "│ 0 ┆ [6845, 7872, … 0] ┆ [1318896000, 1318896000, … 139… │\n",
+ "│ 1 ┆ [815, 10405, … 232] ┆ [1392422400, 1396224000, … 139… │\n",
+ "│ 2 ┆ [6049, 0, … 6608] ┆ [1378425600, 1378425600, … 140… │\n",
+ "│ 3 ┆ [5521, 5160, … 0] ┆ [1379116800, 1380931200, … 138… │\n",
+ "│ 4 ┆ [0, 10469, … 11389] ┆ [1382140800, 1383523200, … 138… │\n",
+ "└─────────┴─────────────────────┴─────────────────────────────────┘"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "grouped_filtered_df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "dc222d59",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Users count: 22363\n",
+ "Items count: 12101\n",
+ "Actions count: 198502\n",
+ "Avg user history len: 8.876358270357287\n"
+ ]
+ }
+ ],
+ "source": [
+ "print('Users count:', filtered_df.select('user_id').unique().shape[0])\n",
+ "print('Items count:', filtered_df.select('item_id').unique().shape[0])\n",
+ "print('Actions count:', filtered_df.shape[0])\n",
+ "print('Avg user history len:', np.mean(list(map(lambda x: x[0], grouped_filtered_df.select(pl.col('item_id').list.len()).rows()))))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "a272855d-84b2-4414-ba9f-62647e1151cf",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "shape: (5, 3)\n",
+ "┌─────┬─────────────────────┬─────────────────────────────────┐\n",
+ "│ uid ┆ item_ids ┆ timestamps │\n",
+ "│ --- ┆ --- ┆ --- │\n",
+ "│ i64 ┆ list[i64] ┆ list[i64] │\n",
+ "╞═════╪═════════════════════╪═════════════════════════════════╡\n",
+ "│ 0 ┆ [6845, 7872, … 0] ┆ [1318896000, 1318896000, … 139… │\n",
+ "│ 1 ┆ [815, 10405, … 232] ┆ [1392422400, 1396224000, … 139… │\n",
+ "│ 2 ┆ [6049, 0, … 6608] ┆ [1378425600, 1378425600, … 140… │\n",
+ "│ 3 ┆ [5521, 5160, … 0] ┆ [1379116800, 1380931200, … 138… │\n",
+ "│ 4 ┆ [0, 10469, … 11389] ┆ [1382140800, 1383523200, … 138… │\n",
+ "└─────┴─────────────────────┴─────────────────────────────────┘\n"
+ ]
+ }
+ ],
+ "source": [
+ "inter_new = grouped_filtered_df.select([\n",
+ " pl.col(\"user_id\").alias(\"uid\"),\n",
+ " pl.col(\"item_id\").alias(\"item_ids\"),\n",
+ " pl.col(\"timestamp\").alias(\"timestamps\")\n",
+ "])\n",
+ "\n",
+ "print(inter_new.head())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "de5a853a-8ee2-42dd-a71a-6cc6f90d526c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Файл успешно сохранен: ../data/Beauty_new/inter_new.parquet\n"
+ ]
+ }
+ ],
+ "source": [
+ "output_path_parquet = interactions_output_parquet_path\n",
+ "inter_new.write_parquet(output_path_parquet)\n",
+ "\n",
+ "print(f\"Файл успешно сохранен: {output_path_parquet}\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "d07a2e91",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "json_data = {}\n",
+ "for user_id, item_ids, _ in grouped_filtered_df.iter_rows():\n",
+ " json_data[user_id] = item_ids\n",
+ "\n",
+ "with open(interactions_output_json_path, 'w') as f:\n",
+ " json.dump(json_data, f, indent=2)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "237523fa",
+ "metadata": {
+ "jp-MarkdownHeadingCollapsed": true
+ },
+ "source": [
+ "## Content embedding creation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "6361c7a5",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "KeyboardInterrupt",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
+ "\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
+ "Cell \u001B[0;32mIn[19], line 5\u001B[0m, in \u001B[0;36mgetDF\u001B[0;34m(path)\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mopen\u001B[39m(path, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mr\u001B[39m\u001B[38;5;124m'\u001B[39m) \u001B[38;5;28;01mas\u001B[39;00m f:\n\u001B[0;32m----> 5\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m line \u001B[38;5;129;01min\u001B[39;00m \u001B[43mf\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreadlines\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m:\n\u001B[1;32m 6\u001B[0m df[i] \u001B[38;5;241m=\u001B[39m \u001B[38;5;28meval\u001B[39m(line)\n",
+ "File \u001B[0;32m/usr/lib/python3.10/codecs.py:319\u001B[0m, in \u001B[0;36mBufferedIncrementalDecoder.decode\u001B[0;34m(self, input, final)\u001B[0m\n\u001B[1;32m 317\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mNotImplementedError\u001B[39;00m\n\u001B[0;32m--> 319\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21mdecode\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;28minput\u001B[39m, final\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m):\n\u001B[1;32m 320\u001B[0m \u001B[38;5;66;03m# decode input (taking the buffer into account)\u001B[39;00m\n\u001B[1;32m 321\u001B[0m data \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbuffer \u001B[38;5;241m+\u001B[39m \u001B[38;5;28minput\u001B[39m\n",
+ "\u001B[0;31mKeyboardInterrupt\u001B[0m: ",
+ "\nDuring handling of the above exception, another exception occurred:\n",
+ "\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
+ "Cell \u001B[0;32mIn[19], line 11\u001B[0m\n\u001B[1;32m 7\u001B[0m i \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m\n\u001B[1;32m 9\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m pd\u001B[38;5;241m.\u001B[39mDataFrame\u001B[38;5;241m.\u001B[39mfrom_dict(df, orient\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mindex\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m---> 11\u001B[0m df \u001B[38;5;241m=\u001B[39m \u001B[43mgetDF\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmetadata_path\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 12\u001B[0m df\u001B[38;5;241m.\u001B[39mhead()\n",
+ "Cell \u001B[0;32mIn[19], line 5\u001B[0m, in \u001B[0;36mgetDF\u001B[0;34m(path)\u001B[0m\n\u001B[1;32m 3\u001B[0m df \u001B[38;5;241m=\u001B[39m {}\n\u001B[1;32m 4\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mopen\u001B[39m(path, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mr\u001B[39m\u001B[38;5;124m'\u001B[39m) \u001B[38;5;28;01mas\u001B[39;00m f:\n\u001B[0;32m----> 5\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m line \u001B[38;5;129;01min\u001B[39;00m \u001B[43mf\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreadlines\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m:\n\u001B[1;32m 6\u001B[0m df[i] \u001B[38;5;241m=\u001B[39m \u001B[38;5;28meval\u001B[39m(line)\n\u001B[1;32m 7\u001B[0m i \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m\n",
+ "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
+ ]
+ }
+ ],
+ "source": [
+ "def getDF(path):\n",
+ " i = 0\n",
+ " df = {}\n",
+ " with open(path, 'r') as f:\n",
+ " for line in f.readlines():\n",
+ " df[i] = eval(line)\n",
+ " i += 1\n",
+ "\n",
+ " return pd.DataFrame.from_dict(df, orient=\"index\")\n",
+ "\n",
+ "df = getDF(metadata_path)\n",
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "971fa89c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def preprocess(row: pd.Series):\n",
+ " row = row.fillna(\"None\")\n",
+ " return f\"Title: {row['title']}. Categories: {', '.join(row['categories'][0])}. Description: {row['description']}.\"\n",
+ "\n",
+ "\n",
+ "def get_data(metadata_df, item_ids_mapping_df):\n",
+ " filtered_df = metadata_df.join(\n",
+ " item_ids_mapping_df, \n",
+ " left_on=\"asin\", \n",
+ " right_on='old_item_id', \n",
+ " how=\"inner\"\n",
+ " ).select(pl.col('new_item_id'), pl.col('title'), pl.col('description'), pl.col('categories'))\n",
+ "\n",
+ " filtered_df = filtered_df.to_pandas()\n",
+ " filtered_df[\"combined_text\"] = filtered_df.apply(preprocess, axis=1)\n",
+ "\n",
+ " return filtered_df\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3b0dd5d5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = get_data(pl.from_pandas(df), item_ids_mapping)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "12e622ff",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = torch.device('cuda:6')\n",
+ "\n",
+ "model_name = \"huggyllama/llama-7b\"\n",
+ "tokenizer = LlamaTokenizer.from_pretrained(model_name)\n",
+ "\n",
+ "if tokenizer.pad_token is None:\n",
+ " tokenizer.pad_token = tokenizer.eos_token\n",
+ "\n",
+ "model = LlamaModel.from_pretrained(model_name)\n",
+ "model = model.to(device)\n",
+ "model = model.eval()\n",
+ "\n",
+ "\n",
+ "class MyDataset:\n",
+ " def __init__(self, data):\n",
+ " self._data = list(zip(data.to_dict()['new_item_id'].values(), data.to_dict()['combined_text'].values()))\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self._data)\n",
+ "\n",
+ " def __getitem__(self, idx):\n",
+ " text = self._data[idx][1]\n",
+ " inputs = tokenizer(text, return_tensors=\"pt\", max_length=1024, truncation=True, padding=\"max_length\")\n",
+ " return {\n",
+ " 'item_id': self._data[idx][0],\n",
+ " 'input_ids': inputs['input_ids'][0],\n",
+ " 'attention_mask': inputs['attention_mask'][0]\n",
+ " }\n",
+ " \n",
+ "\n",
+ "dataset = MyDataset(data)\n",
+ "loader = DataLoader(dataset, batch_size=8, drop_last=False, shuffle=False, num_workers=10)\n",
+ "\n",
+ "\n",
+ "new_df = {\n",
+ " 'item_id': [],\n",
+ " 'embedding': []\n",
+ "}\n",
+ "\n",
+ "for batch in tqdm(loader):\n",
+ " with torch.inference_mode():\n",
+ " outputs = model(\n",
+ " input_ids=batch[\"input_ids\"].to(device), \n",
+ " attention_mask=batch[\"attention_mask\"].to(device)\n",
+ " )\n",
+ " embeddings = outputs.last_hidden_state\n",
+ " \n",
+ " embeddings = outputs.last_hidden_state # (bs, sl, ed)\n",
+ " embeddings[(~batch[\"attention_mask\"].bool())] = 0. # (bs, sl, ed)\n",
+ "\n",
+ " new_df['item_id'] += batch['item_id'].tolist()\n",
+ " new_df['embedding'] += embeddings.mean(dim=1).tolist() # (bs, ed)\n",
+ "\n",
+ "\n",
+ "with open(embeddings_output_path, 'wb') as f:\n",
+ " pickle.dump(new_df, f)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a6fffc4a-85f1-424e-b460-29e526df3317",
+ "metadata": {},
+ "source": [
+ "# Test"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "id": "1f922431-e3c1-4587-86d1-04a58eb8ffee",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2025-11-29T16:36:01.750784Z",
+ "start_time": "2025-11-29T16:36:01.638604Z"
+ }
+ },
+ "source": [
+ "# df = pl.read_parquet(interactions_output_parquet_path)\n",
+ "# \n",
+ "# df_timestamps = df.select(\n",
+ "# pl.col(\"timestamps\").explode()\n",
+ "# )\n",
+ "# min_time = df_timestamps.select(pl.col(\"timestamps\").min()).item()\n",
+ "# max_time = df_timestamps.select(pl.col(\"timestamps\").max()).item()\n",
+ "# \n",
+ "# cutoffs = [\n",
+ "# min_time + (max_time - min_time) * 0.7, # 70%\n",
+ "# min_time + (max_time - min_time) * 0.8, # 80%\n",
+ "# min_time + (max_time - min_time) * 0.9, # 90%\n",
+ "# ]\n"
+ ],
+ "outputs": [
+ {
+ "ename": "NameError",
+ "evalue": "name 'pl' is not defined",
+ "output_type": "error",
+ "traceback": [
+ "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
+ "\u001B[0;31mNameError\u001B[0m Traceback (most recent call last)",
+ "Cell \u001B[0;32mIn[1], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m df \u001B[38;5;241m=\u001B[39m \u001B[43mpl\u001B[49m\u001B[38;5;241m.\u001B[39mread_parquet(interactions_output_parquet_path)\n\u001B[1;32m 3\u001B[0m df_timestamps \u001B[38;5;241m=\u001B[39m df\u001B[38;5;241m.\u001B[39mselect(\n\u001B[1;32m 4\u001B[0m pl\u001B[38;5;241m.\u001B[39mcol(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtimestamps\u001B[39m\u001B[38;5;124m\"\u001B[39m)\u001B[38;5;241m.\u001B[39mexplode()\n\u001B[1;32m 5\u001B[0m )\n\u001B[1;32m 6\u001B[0m min_time \u001B[38;5;241m=\u001B[39m df_timestamps\u001B[38;5;241m.\u001B[39mselect(pl\u001B[38;5;241m.\u001B[39mcol(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtimestamps\u001B[39m\u001B[38;5;124m\"\u001B[39m)\u001B[38;5;241m.\u001B[39mmin())\u001B[38;5;241m.\u001B[39mitem()\n",
+ "\u001B[0;31mNameError\u001B[0m: name 'pl' is not defined"
+ ]
+ }
+ ],
+ "execution_count": 1
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/sigir/exps_data.ipynb b/sigir/exps_data.ipynb
new file mode 100644
index 0000000..bf170a5
--- /dev/null
+++ b/sigir/exps_data.ipynb
@@ -0,0 +1,377 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "e2462a97-6705-44e1-a232-4dd78a5dfc85",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import polars as pl\n",
+ "import json\n",
+ "from typing import List, Dict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "fd38624d-5796-4aa5-929f-7e82c5544f6c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "interactions_output_parquet_path = '../data/Beauty_new/inter_new.parquet'\n",
+ "df = pl.read_parquet(interactions_output_parquet_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "ee127317-66b8-4f22-9109-94bcb8b1f1ae",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def split_session_by_timestamps(\n",
+ " df: pl.DataFrame,\n",
+ " time_cutoffs: List[int],\n",
+ " output_dir: str = None,\n",
+ ") -> List[Dict[int, List[int]]]:\n",
+ " \n",
+ " result_dicts = []\n",
+ " \n",
+ " def extract_interval(df_source, start, end=None):\n",
+ " q = df_source.lazy()\n",
+ " q = q.explode([\"item_ids\", \"timestamps\"])\n",
+ " \n",
+ " if end is not None:\n",
+ " q = q.filter(\n",
+ " (pl.col(\"timestamps\") >= start) & \n",
+ " (pl.col(\"timestamps\") < end)\n",
+ " )\n",
+ " else:\n",
+ " q = q.filter(\n",
+ " pl.col(\"timestamps\") >= start\n",
+ " )\n",
+ " \n",
+ " q = q.group_by(\"uid\").agg([\n",
+ " pl.col(\"item_ids\").alias(\"item_ids\")\n",
+ " ]).sort(\"uid\")\n",
+ " \n",
+ " return q.collect()\n",
+ " \n",
+ " intervals = []\n",
+ " current_start = 0\n",
+ " for cutoff in time_cutoffs:\n",
+ " intervals.append((current_start, cutoff))\n",
+ " current_start = cutoff\n",
+ " # от последнего cutoff до бесконечности\n",
+ " intervals.append((current_start, None))\n",
+ " \n",
+ " for start, end in intervals:\n",
+ " subset = extract_interval(df, start, end)\n",
+ " \n",
+ " json_dict = {}\n",
+ " for user_id, item_ids in subset.iter_rows():\n",
+ " json_dict[user_id] = item_ids\n",
+ " \n",
+ " result_dicts.append(json_dict)\n",
+ " \n",
+ " if output_dir:\n",
+ " if end is not None:\n",
+ " filename = f\"inter_new_[{start}_{end}).json\"\n",
+ " else:\n",
+ " filename = f\"inter_new_[{start}_inf).json\"\n",
+ " \n",
+ " filepath = f\"{output_dir}/{filename}\"\n",
+ " with open(filepath, 'w') as f:\n",
+ " json.dump(json_dict, f, indent=2)\n",
+ " \n",
+ " print(f\"Сохранено: {filepath}\")\n",
+ " \n",
+ " return result_dicts"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "efc8b582-dd8a-4299-9c49-de906251de8a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Cutoffs: [1402444800, 1403654400, 1404864000]\n",
+ "✓ Сохранено: ../sigir/Beauty_new/splits/raw/inter_new_[0_1402444800).json\n",
+ "✓ Сохранено: ../sigir/Beauty_new/splits/raw/inter_new_[1402444800_1403654400).json\n",
+ "✓ Сохранено: ../sigir/Beauty_new/splits/raw/inter_new_[1403654400_1404864000).json\n",
+ "✓ Сохранено: ../sigir/Beauty_new/splits/raw/inter_new_[1404864000_inf).json\n",
+ "Part 0 [Base]: 22029 users\n",
+ "Part 1 [Week -6]: 1854 users\n",
+ "Part 2 [Week -4]: 1945 users\n",
+ "Part 3 [Week -2]: 1381 users\n"
+ ]
+ }
+ ],
+ "source": [
+ "global_max_time = df.select(\n",
+ " pl.col(\"timestamps\").explode().max()\n",
+ ").item()\n",
+ "\n",
+ "days_val = 14\n",
+ "window_sec = days_val * 24 * 3600 \n",
+ "\n",
+ "cutoff_test_start = global_max_time - window_sec # T - 2w\n",
+ "cutoff_val_start = global_max_time - 2 * window_sec # T - 4w\n",
+ "cutoff_gap_start = global_max_time - 3 * window_sec # T - 6w\n",
+ "\n",
+ "cutoffs = [\n",
+ " int(cutoff_gap_start),\n",
+ " int(cutoff_val_start),\n",
+ " int(cutoff_test_start)\n",
+ "]\n",
+ "\n",
+ "print(f\"Cutoffs: {cutoffs}\")\n",
+ "\n",
+ "split_files = split_session_by_timestamps(\n",
+ " df, \n",
+ " cutoffs, \n",
+ " output_dir=\"../sigir/Beauty_new/splits/raw\"\n",
+ ")\n",
+ "\n",
+ "names = [\"Base\", \"Week -6\", \"Week -4\", \"Week -2\"]\n",
+ "for i, d in enumerate(split_files):\n",
+ " print(f\"Part {i} [{names[i]}]: {len(d)} users\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "d5ba172e-b430-40a3-a4fa-64366d02a015",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "--- Creating Experiments ---\n"
+ ]
+ }
+ ],
+ "source": [
+ "def merge_and_save(parts_to_merge, dirr, output_name):\n",
+ " merged = {}\n",
+ " print(f\"Merging {len(parts_to_merge)} files into {output_name}...\")\n",
+ " \n",
+ " for part in parts_to_merge:\n",
+ " # with open(fp, 'r') as f:\n",
+ " # part = json.load(f)\n",
+ " for uid, items in part.items():\n",
+ " if uid not in merged:\n",
+ " merged[uid] = []\n",
+ " merged[uid].extend(items)\n",
+ " \n",
+ " out_path = f\"{dirr}/{output_name}\"\n",
+ " with open(out_path, 'w') as f:\n",
+ " json.dump(merged, f)\n",
+ " print(f\"Done: {out_path} (Users: {len(merged)})\")\n",
+ "\n",
+ "p0, p1, p2, p3 = split_files[0], split_files[1], split_files[2], split_files[3]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "d116b7e0-9bf9-4104-86a0-69788a70cc14",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Merging 2 files into exp_4_inter_tiger_train.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4_inter_tiger_train.json (Users: 22129)\n",
+ "Merging 2 files into exp_4.1_inter_semantics_train.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4.1_inter_semantics_train.json (Users: 22129)\n",
+ "Merging 1 files into exp_4.2_inter_semantics_train_short.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4.2_inter_semantics_train_short.json (Users: 22029)\n",
+ "Merging 3 files into exp_4.3_inter_semantics_train_leak.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4.3_inter_semantics_train_leak.json (Users: 22265)\n",
+ "Merging 1 files into test_set.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/test_set.json (Users: 1381)\n",
+ "Merging 1 files into valid_skip_set.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/valid_skip_set.json (Users: 1945)\n",
+ "\n",
+ "All done!\n"
+ ]
+ }
+ ],
+ "source": [
+ "EXP_DIR = \"../sigir/Beauty_new/splits/exp_data\"\n",
+ "\n",
+ "# Tiger: P0+P1\n",
+ "merge_and_save([p0, p1], EXP_DIR, \"exp_4_inter_tiger_train.json\")\n",
+ "\n",
+ "# 1. Exp 4.1\n",
+ "# Semantics: P0+P1 (Всё кроме валида и теста)\n",
+ "merge_and_save([p0, p1], EXP_DIR, \"exp_4.1_inter_semantics_train.json\")\n",
+ "\n",
+ "# 2. Exp 4.2\n",
+ "# Semantics: P0 (Короче на неделю, без P1)\n",
+ "merge_and_save([p0], EXP_DIR, \"exp_4.2_inter_semantics_train_short.json\")\n",
+ "\n",
+ "# 3. Exp 4.3\n",
+ "# Semantics: P0+P1+P2 (Видит валидацию)\n",
+ "merge_and_save([p0, p1, p2], EXP_DIR, \"exp_4.3_inter_semantics_train_leak.json\")\n",
+ "\n",
+ "# 4. Test Set (тест всех моделей)\n",
+ "merge_and_save([p3], EXP_DIR, \"test_set.json\")\n",
+ "\n",
+ "# 4. Valid Set (пропуск, имитируется разница трейна и теста чтобы потом дообучать)\n",
+ "merge_and_save([p2], EXP_DIR, \"valid_skip_set.json\")\n",
+ "\n",
+ "print(\"\\nAll done!\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "9ae1d1e5-567d-471a-8f83-4039ecacc8d2",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Merging 4 files into all_set.json...\n",
+ "✓ Done: ../sigir/Beauty_new/splits/exp_data/all_set.json (Users: 22363)\n"
+ ]
+ }
+ ],
+ "source": [
+ "merge_and_save([p0, p1, p2, p3], EXP_DIR, \"all_set.json\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "id": "328de16c-f61d-45be-8a72-5f0eaef612e8",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Проверка Train сетов (должны быть префиксами):\n",
+ "✅ [ПРЕФИКСЫ] Все 22129 массивов ОК. Полных совпадений: 19410\n",
+ "✅ [ПРЕФИКСЫ] Все 22029 массивов ОК. Полных совпадений: 18191\n",
+ "✅ [ПРЕФИКСЫ] Все 22265 массивов ОК. Полных совпадений: 20982\n",
+ "✅ [ПРЕФИКСЫ] Все 22129 массивов ОК. Полных совпадений: 19410\n",
+ "\n",
+ "Проверка Test сета (должен быть суффиксом):\n",
+ "✅ [СУФФИКСЫ] Все 1381 массивов ОК. Полных совпадений: 98\n",
+ "\n",
+ "(Контроль) Проверка Test сета как префикса (должна упасть):\n",
+ "❌ [ПРЕФИКСЫ] Найдено 1283 ошибок.\n"
+ ]
+ }
+ ],
+ "source": [
+ "with open(\"../data/Beauty/inter_new.json\", 'r') as f:\n",
+ " old_inter_new = json.load(f)\n",
+ "\n",
+ "with open(\"../sigir/Beauty_new/splits/exp_data/exp_4.1_inter_semantics_train.json\", 'r') as ff:\n",
+ " first_sem = json.load(ff)\n",
+ " \n",
+ "with open(\"../sigir/Beauty_new/splits/exp_data/exp_4.2_inter_semantics_train_short.json\", 'r') as ff:\n",
+ " second_sem = json.load(ff)\n",
+ " \n",
+ "with open(\"../sigir/Beauty_new/splits/exp_data/exp_4.3_inter_semantics_train_leak.json\", 'r') as ff:\n",
+ " third_sem = json.load(ff)\n",
+ " \n",
+ "with open(\"../sigir/Beauty_new/splits/exp_data/exp_4_inter_tiger_train.json\", 'r') as ff:\n",
+ " tiger_sem = json.load(ff)\n",
+ "\n",
+ "with open(\"../sigir/Beauty_new/splits/exp_data/test_set.json\", 'r') as ff:\n",
+ " test_sem = json.load(ff)\n",
+ "\n",
+ "def check_prefix_match(full_data, subset_data, check_suffix=False):\n",
+ " \"\"\"\n",
+ " check_suffix=True включит режим проверки суффиксов (для теста).\n",
+ " \"\"\"\n",
+ " mismatch_count = 0\n",
+ " full_match_count = 0\n",
+ " \n",
+ " for user, sub_items in subset_data.items():\n",
+ " if user not in full_data:\n",
+ " print(f\"Юзер {user} не найден в исходном файле!\")\n",
+ " mismatch_count += 1\n",
+ " continue\n",
+ " \n",
+ " full_items = full_data[user]\n",
+ " \n",
+ " if not check_suffix:\n",
+ " if len(sub_items) > len(full_items):\n",
+ " mismatch_count += 1\n",
+ " continue\n",
+ " \n",
+ " if full_items[:len(sub_items)] == sub_items:\n",
+ " if len(full_items) == len(sub_items):\n",
+ " full_match_count += 1\n",
+ " else:\n",
+ " mismatch_count += 1\n",
+ " \n",
+ " else:\n",
+ " if len(sub_items) > len(full_items):\n",
+ " mismatch_count += 1\n",
+ " continue\n",
+ " \n",
+ " if full_items[-len(sub_items):] == sub_items:\n",
+ " if len(full_items) == len(sub_items):\n",
+ " full_match_count += 1\n",
+ " else:\n",
+ " mismatch_count += 1\n",
+ "\n",
+ " mode = \"СУФФИКСЫ\" if check_suffix else \"ПРЕФИКСЫ\"\n",
+ " \n",
+ " if mismatch_count == 0:\n",
+ " print(f\"[{mode}] Все {len(subset_data)} массивов ОК. Полных совпадений: {full_match_count}\")\n",
+ " else:\n",
+ " print(f\"[{mode}] Найдено {mismatch_count} ошибок.\")\n",
+ "\n",
+ "print(\"Проверка Train сетов (должны быть префиксами):\")\n",
+ "check_prefix_match(old_inter_new, first_sem)\n",
+ "check_prefix_match(old_inter_new, second_sem)\n",
+ "check_prefix_match(old_inter_new, third_sem)\n",
+ "check_prefix_match(old_inter_new, tiger_sem)\n",
+ "\n",
+ "print(\"\\nПроверка Test сета (должен быть суффиксом):\")\n",
+ "check_prefix_match(old_inter_new, test_sem, check_suffix=True)\n",
+ "\n",
+ "print(\"\\n(Контроль) Проверка Test сета как префикса (должна упасть):\")\n",
+ "check_prefix_match(old_inter_new, test_sem, check_suffix=False)\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}