diff --git a/scripts/plum/cooc_data.py b/scripts/plum/cooc_data.py index b11e6f0..50f2bdd 100644 --- a/scripts/plum/cooc_data.py +++ b/scripts/plum/cooc_data.py @@ -13,21 +13,15 @@ class CoocMappingDataset: def __init__( self, train_sampler, - validation_sampler, - test_sampler, num_items, - max_sequence_length, cooccur_counter_mapping=None ): self._train_sampler = train_sampler - self._validation_sampler = validation_sampler - self._test_sampler = test_sampler self._num_items = num_items - self._max_sequence_length = max_sequence_length self._cooccur_counter_mapping = cooccur_counter_mapping @classmethod - def create(cls, inter_json_path, max_sequence_length, sampler_type, window_size): + def create(cls, inter_json_path, window_size): max_item_id = 0 train_dataset, validation_dataset, test_dataset = [], [], [] @@ -43,31 +37,59 @@ def create(cls, inter_json_path, max_sequence_length, sampler_type, window_size) 'user.ids': [user_id], 'item.ids': item_ids[:-2], }) - validation_dataset.append({ - 'user.ids': [user_id], - 'item.ids': item_ids[:-1], - }) - test_dataset.append({ - 'user.ids': [user_id], - 'item.ids': item_ids, - }) cooccur_counter_mapping = cls.build_cooccur_counter_mapping(train_dataset, window_size=window_size) logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items but max_item_id is {max_item_id}') train_sampler = train_dataset - validation_sampler = validation_dataset - test_sampler = test_dataset return cls( train_sampler=train_sampler, - validation_sampler=validation_sampler, - test_sampler=test_sampler, num_items=max_item_id + 1, - max_sequence_length=max_sequence_length, cooccur_counter_mapping=cooccur_counter_mapping ) + @classmethod + def create_from_split_part( + cls, + train_inter_json_path, + window_size + ): + + max_item_id = 0 + train_dataset = [] + + with open(train_inter_json_path, 'r') as f: + train_interactions = json.load(f) + + # Обрабатываем TRAIN + for user_id_str, item_ids in train_interactions.items(): + user_id = int(user_id_str) + if item_ids: + max_item_id = max(max_item_id, max(item_ids)) + + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids, + }) + + logger.debug(f'Train: {len(train_dataset)} users') + logger.debug(f'Max item ID: {max_item_id}') + + cooccur_counter_mapping = cls.build_cooccur_counter_mapping( + train_dataset, + window_size=window_size + ) + + logger.debug(f'Computed window-based co-occurrence mapping for {len(cooccur_counter_mapping)} items') + + return cls( + train_sampler=train_dataset, + num_items=max_item_id + 1, + cooccur_counter_mapping=cooccur_counter_mapping + ) + + @staticmethod def build_cooccur_counter_mapping(train_dataset, window_size): #TODO передавать время и по нему строить окно cooccur_counts = defaultdict(Counter) @@ -80,16 +102,6 @@ def build_cooccur_counter_mapping(train_dataset, window_size): #TODO перед cooccur_counts[item_i][items[j]] += 1 return cooccur_counts - def get_datasets(self): - return self._train_sampler, self._validation_sampler, self._test_sampler - - @property - def num_items(self): - return self._num_items - - @property - def max_sequence_length(self): - return self._max_sequence_length @property def cooccur_counter_mapping(self): diff --git a/scripts/plum/infer_default.py b/scripts/plum/infer_default.py index af8df34..b15fb6d 100644 --- a/scripts/plum/infer_default.py +++ b/scripts/plum/infer_default.py @@ -12,8 +12,18 @@ from data import EmbeddingDataset, ProcessEmbeddings from models import PlumRQVAE -from transforms import 
AddWeightedCooccurrenceEmbeddings -from cooc_data import CoocMappingDataset + +# ПУТИ +IREC_PATH = '/home/jovyan/IRec/' +EMBEDDINGS_PATH = '/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl' +MODEL_PATH = '/home/jovyan/IRec/checkpoints/4-1_plum_rqvae_beauty_ws_2_best_0.0051.pth' +RESULTS_PATH = os.path.join(IREC_PATH, 'results') + +WINDOW_SIZE = 2 + +EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}' + +# ОСТАЛЬНОЕ SEED_VALUE = 42 DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') @@ -26,29 +36,16 @@ NUM_CODEBOOKS = 3 BETA = 0.25 -MODEL_PATH = '/home/jovyan/IRec/checkpoints/test_plum_rqvae_beauty_ws_2_best_0.0054.pth' -WINDOW_SIZE = 2 - -EXPERIMENT_NAME = f'test_plum_rqvae_beauty_ws_{WINDOW_SIZE}' - -IREC_PATH = '/home/jovyan/IRec/' def main(): fix_random_seed(SEED_VALUE) - data = CoocMappingDataset.create( - inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter_new.json'), - max_sequence_length=20, - sampler_type='sasrec', - window_size=WINDOW_SIZE - ) - dataset = EmbeddingDataset( - data_path='/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl' + data_path=EMBEDDINGS_PATH ) - + item_id_to_embedding = {} all_item_ids = [] for idx in range(len(dataset)): @@ -57,15 +54,12 @@ def main(): item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) all_item_ids.append(item_id) - add_cooc_transform = AddWeightedCooccurrenceEmbeddings( - data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids) - dataloader = DataLoader( dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False, - ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform) + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])) model = PlumRQVAE( input_dim=INPUT_DIM, @@ -106,8 +100,8 @@ def main(): cb.Logger().every_num_steps(len(dataloader)), cb.InferenceSaver( - metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, - save_path=f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json', + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), format='json' ) ] @@ -125,9 +119,9 @@ def main(): from collections import defaultdict import numpy as np - with open(f'/home/jovyan/IRec/results/{EXPERIMENT_NAME}_clusters.json', 'r') as f: + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f: mappings = json.load(f) - + inter = {} sem_2_ids = defaultdict(list) for mapping in mappings: @@ -143,8 +137,8 @@ def main(): inter[item_id].append(collision_solver) for i in range(len(inter[item_id])): inter[item_id][i] += CODEBOOK_SIZE * i - - with open(os.path.join(IREC_PATH, 'results', f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: + + with open(os.path.join(RESULTS_PATH, f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f: json.dump(inter, f, indent=2) diff --git a/scripts/plum/train_plum.py b/scripts/plum/train_plum.py index 5a00bc3..ffa9e43 100644 --- a/scripts/plum/train_plum.py +++ b/scripts/plum/train_plum.py @@ -41,8 +41,6 @@ def main(): data = CoocMappingDataset.create( inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter_new.json'), - max_sequence_length=20, - sampler_type='sasrec', window_size=WINDOW_SIZE ) diff --git 
a/scripts/plum/train_plum_timestamp_based.py b/scripts/plum/train_plum_timestamp_based.py new file mode 100644 index 0000000..e755d95 --- /dev/null +++ b/scripts/plum/train_plum_timestamp_based.py @@ -0,0 +1,168 @@ +from loguru import logger +import os + +import torch + +import pickle + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import TrainingRunner + +from irec.utils import fix_random_seed + +from callbacks import InitCodebooks, FixDeadCentroids +from data import EmbeddingDataset, ProcessEmbeddings +from models import PlumRQVAE +from transforms import AddWeightedCooccurrenceEmbeddings +from cooc_data import CoocMappingDataset + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +NUM_EPOCHS = 500 +BATCH_SIZE = 1024 + +INPUT_DIM = 4096 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 3 +BETA = 0.25 +LR = 1e-4 +WINDOW_SIZE = 2 + +EXPERIMENT_NAME = f'4-1_plum_rqvae_beauty_ws_{WINDOW_SIZE}' +INTER_TRAIN_PATH = "/home/jovyan/IRec/sigir/Beauty_new/splits/exp_data/exp_4.1_inter_semantics_train.json" +EMBEDDINGS_PATH = "/home/jovyan/tiger/data/Beauty/default_content_embeddings.pkl" +IREC_PATH = '../../' + +def main(): + fix_random_seed(SEED_VALUE) + + data = CoocMappingDataset.create_from_split_part( + train_inter_json_path=INTER_TRAIN_PATH, + window_size=WINDOW_SIZE + ) + + dataset = EmbeddingDataset( + data_path=EMBEDDINGS_PATH + ) + + item_id_to_embedding = {} + all_item_ids = [] + for idx in range(len(dataset)): + sample = dataset[idx] + item_id = int(sample['item_id']) + item_id_to_embedding[item_id] = torch.tensor(sample['embedding']) + all_item_ids.append(item_id) + + add_cooc_transform = AddWeightedCooccurrenceEmbeddings( + data.cooccur_counter_mapping, item_id_to_embedding, all_item_ids) + + train_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map( + ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']) + ).map(add_cooc_transform).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).map(add_cooc_transform) + + LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + model = PlumRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + contrastive_loss_weight=1.0, + temperature=1.0 + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True) + + callbacks = [ + InitCodebooks(valid_dataloader), + + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='train'), + + FixDeadCentroids(valid_dataloader), + + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recon_loss': cb.MeanAccumulator(), + 'train/rqvae_loss': 
cb.MeanAccumulator(), + 'train/con_loss': cb.MeanAccumulator(), + 'num_dead/0': cb.MeanAccumulator(), + 'num_dead/1': cb.MeanAccumulator(), + 'num_dead/2': cb.MeanAccumulator(), + }, + reset_every_num_steps=LOG_EVERY_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloader, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'con_loss': model_outputs['con_loss'] + }, name='valid'), + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/con_loss': cb.MeanAccumulator() + } + ), + ], + ).every_num_steps(LOG_EVERY_NUM_STEPS), + + cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.EarlyStopping( + metric='valid/recon_loss', + patience=40, + minimize=True, + model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME) + ).every_num_steps(LOG_EVERY_NUM_STEPS), + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger/data.py b/scripts/tiger/data.py index 188993a..760fe8c 100644 --- a/scripts/tiger/data.py +++ b/scripts/tiger/data.py @@ -28,6 +28,113 @@ def __init__( self._num_items = num_items self._max_sequence_length = max_sequence_length + @classmethod + def create_timestamp_based( + cls, + train_json_path, + validation_json_path, + test_json_path, + max_sequence_length, + sampler_type, + min_sample_len=2, + is_extended=False + ): + max_item_id = 0 + train_dataset, validation_dataset, test_dataset = [], [], [] + + with open(train_json_path, 'r') as f: + train_data = json.load(f) + with open(validation_json_path, 'r') as f: + validation_data = json.load(f) + with open(test_json_path, 'r') as f: + test_data = json.load(f) + + all_users = set(train_data.keys()) | set(validation_data.keys()) | set(test_data.keys()) + + for user_id_str in all_users: + user_id = int(user_id_str) + + train_items = train_data.get(user_id_str, []) + validation_items = validation_data.get(user_id_str, []) + test_items = test_data.get(user_id_str, []) + + full_sequence = train_items + validation_items + test_items + if full_sequence: + max_item_id = max(max_item_id, max(full_sequence)) + + assert len(full_sequence) >= 5, f'Core-5 dataset is used, user {user_id} has only {len(full_sequence)} items' + + if is_extended: + # sample = [1, 2] + # sample = [1, 2, 3] + # sample = [1, 2, 3, 4] + # sample = [1, 2, 3, 4, 5] + # sample = [1, 2, 3, 4, 5, 6] + # sample = [1, 2, 3, 4, 5, 6, 7] + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + for prefix_length in range(min_sample_len, len(train_items) + 1): + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': train_items[:prefix_length], + }) + else: + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': train_items, + }) + + # разворачиваем каждый айтем из валидации в отдельный сэмпл + # Пример: Train=[1,2], Valid=[3,4] + # sample = [1, 2, 3] + # sample = [1, 2, 3, 4] + + current_history = train_items.copy() + for item in validation_items: + # эвал датасет сам отрезает таргет потом + sample_sequence = current_history + [item] + + if 
len(sample_sequence) >= min_sample_len: + validation_dataset.append({ + 'user.ids': [user_id], + 'item.ids': sample_sequence, + }) + current_history.append(item) + + # разворачиваем каждый айтем из теста в отдельный сэмпл + # Пример: Train=[1,2], Valid=[3,4], Test=[5, 6] + # sample = [1, 2, 3, 4, 5] + # sample = [1, 2, 3, 4, 5, 6] + current_history = train_items + validation_items + + for item in test_items: + # эвал датасет сам отрезает таргет потом + sample_sequence = current_history + [item] + + if len(sample_sequence) >= min_sample_len: + test_dataset.append({ + 'user.ids': [user_id], + 'item.ids': sample_sequence, + }) + + current_history.append(item) + + logger.debug(f'Train dataset size: {len(train_dataset)}') + logger.debug(f'Validation dataset size: {len(validation_dataset)}') + logger.debug(f'Test dataset size: {len(test_dataset)}') + + train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length) + validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length) + test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length) + + return cls( + train_sampler=train_sampler, + validation_sampler=validation_sampler, + test_sampler=test_sampler, + num_items=max_item_id + 1, # +1 added because our ids are 0-indexed + max_sequence_length=max_sequence_length + ) + @classmethod def create(cls, inter_json_path, max_sequence_length, sampler_type, is_extended=False): max_item_id = 0 diff --git a/scripts/tiger/train.py b/scripts/tiger/train.py index f436dd4..1a2d347 100644 --- a/scripts/tiger/train.py +++ b/scripts/tiger/train.py @@ -14,10 +14,23 @@ from data import ArrowBatchDataset from models import TigerModel, CorrectItemsLogitsProcessor + +# ПУТИ +IREC_PATH = '../../' +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-1_plum_rqvae_beauty_ws_2_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_eval_batches/') + +TENSORBOARD_LOGDIR = os.path.join(IREC_PATH, 'tensorboard_logs') +CHECKPOINTS_DIR = os.path.join(IREC_PATH, 'checkpoints') + +EXPERIMENT_NAME = 'tiger_beauty_4-1_plum_ws_2_dp_0.2' + +# ОСТАЛЬНОЕ SEED_VALUE = 42 DEVICE = 'cuda' -EXPERIMENT_NAME = 'tiger_beauty' NUM_EPOCHS = 300 MAX_SEQ_LEN = 20 TRAIN_BATCH_SIZE = 256 @@ -30,13 +43,12 @@ NUM_LAYERS = 4 FEEDFORWARD_DIM = 1024 KV_DIM = 64 -DROPOUT = 0.1 +DROPOUT = 0.2 NUM_BEAMS = 30 TOP_K = 20 NUM_CODEBOOKS = 4 -LR = 3e-4 +LR = 0.0001 -IREC_PATH = '../../' torch.set_float32_matmul_precision('high') torch._dynamo.config.capture_scalar_outputs = True @@ -48,30 +60,30 @@ def main(): fix_random_seed(SEED_VALUE) - with open(os.path.join(IREC_PATH, 'results/rqvae_beauty_best_clusters_colisionless.json'), 'r') as f: + with open(SEMANTIC_MAPPING_PATH, 'r') as f: mappings = json.load(f) - + train_dataloader = DataLoader( ArrowBatchDataset( - os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/'), - device='cpu', + TRAIN_BATCHES_DIR, + device='cpu', preload=True ), - batch_size=1, - shuffle=True, + batch_size=1, + shuffle=True, num_workers=0, - pin_memory=True, + pin_memory=True, collate_fn=Collate() ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS) valid_dataloder = ArrowBatchDataset( - os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/'), + VALID_BATCHES_DIR, device=DEVICE, preload=True ) eval_dataloder = ArrowBatchDataset( - 
os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/'), + EVAL_BATCHES_DIR, device=DEVICE, preload=True ) @@ -177,22 +189,22 @@ def main(): ), ], ).every_num_steps(EPOCH_NUM_STEPS), - + cb.Logger().every_num_steps(EPOCH_NUM_STEPS), - cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=TENSORBOARD_LOGDIR), cb.EarlyStopping( - metric='eval/ndcg@20', + metric='eval/ndcg@20', patience=40, minimize=False, - model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME) + model_path=os.path.join(CHECKPOINTS_DIR, EXPERIMENT_NAME) ).every_num_steps(EPOCH_NUM_STEPS) # cb.Profiler( # wait=10, # warmup=10, # active=10, - # logdir=os.path.join(IREC_PATH, 'tensorboard_logs') + # logdir=TENSORBOARD_LOGDIR # ), # cb.StopAfterNumSteps(40) diff --git a/scripts/tiger/varka.py b/scripts/tiger/varka.py index ed47595..4dc3e02 100644 --- a/scripts/tiger/varka.py +++ b/scripts/tiger/varka.py @@ -15,6 +15,20 @@ from data import Dataset + + +# ПУТИ + +IREC_PATH = '../../' +INTERACTIONS_PATH = os.path.join(IREC_PATH, 'data/Beauty/inter.json') +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results/rqvae_beauty_best_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/') + + +# ОСТАЛЬНОЕ + SEED_VALUE = 42 DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') @@ -32,8 +46,6 @@ DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3, -IREC_PATH = '../../' - class TigerProcessing(Transform): def __call__(self, batch): @@ -42,12 +54,12 @@ def __call__(self, batch): input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ??? 
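+ # Padded positions are overwritten with PAD_TOKEN_ID above so the encoder ignores them.
+ # np.concat is only available as an alias in NumPy >= 2.0; np.concatenate below behaves
+ # identically and also works on older NumPy versions.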
- input_semantic_ids = np.concat([ + input_semantic_ids = np.concatenate([ input_semantic_ids, NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None] ], axis=-1) - attention_mask = np.concat([ + attention_mask = np.concatenate([ attention_mask, np.ones((batch_size, 1), dtype=attention_mask.dtype) ], axis=-1) @@ -56,7 +68,7 @@ def __call__(self, batch): batch['input.mask'] = attention_mask target_semantic_ids = batch['labels.semantic.padded'] - target_semantic_ids = np.concat([ + target_semantic_ids = np.concatenate([ np.ones( (batch_size, 1), dtype=np.int64, @@ -73,7 +85,7 @@ class ToMasked(Transform): def __init__(self, prefix, is_right_aligned=False): self._prefix = prefix self._is_right_aligned = is_right_aligned - + def __call__(self, batch): data = batch[f'{self._prefix}.ids'] lengths = batch[f'{self._prefix}.length'] @@ -92,7 +104,7 @@ def __call__(self, batch): (batch_size, max_sequence_length, data.shape[-1]), dtype=data.dtype ) # (batch_size, max_seq_len, emb_dim) - + mask = np.arange(max_sequence_length)[None] < lengths[:, None] if self._is_right_aligned: @@ -117,10 +129,10 @@ def __init__(self, mapping, names=[]): data.append(mapping[str(i)]) self._mapping_tensor = torch.tensor(data, dtype=torch.long) self._semantic_length = self._mapping_tensor.shape[-1] - + def __call__(self, batch): for name in self._names: - if f'{name}.ids' in batch: + if f'{name}.ids' in batch: ids = batch[f'{name}.ids'] lengths = batch[f'{name}.length'] assert ids.min() >= 0 @@ -135,7 +147,7 @@ class UserHashing(Transform): def __init__(self, hash_size): super().__init__() self._hash_size = hash_size - + def __call__(self, batch): batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64) return batch @@ -144,7 +156,7 @@ def __call__(self, batch): def save_batches_to_arrow(batches, output_dir): output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=False) - + for batch_idx, batch in enumerate(batches): length_groups = defaultdict(dict) metadata_groups = defaultdict(dict) @@ -164,7 +176,7 @@ def save_batches_to_arrow(batches, output_dir): else: # >2D массив - flatten и сохраняем shape length_groups[length][key] = value.flatten() - + for length, fields in length_groups.items(): arrow_dict = {} for k, v in fields.items(): @@ -173,11 +185,11 @@ def save_batches_to_arrow(batches, output_dir): arrow_dict[k] = pa.array(v) else: arrow_dict[k] = pa.array(v) - + table = pa.table(arrow_dict) if length in metadata_groups: table = table.replace_schema_metadata(metadata_groups[length]) - + feather.write_feather( table, output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", @@ -186,7 +198,7 @@ def save_batches_to_arrow(batches, output_dir): # arrow_dict = {k: pa.array(v) for k, v in fields.items()} # table = pa.table(arrow_dict) - + # feather.write_feather( # table, # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", @@ -196,15 +208,15 @@ def save_batches_to_arrow(batches, output_dir): def main(): data = Dataset.create( - inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter.json'), + inter_json_path=INTERACTIONS_PATH, max_sequence_length=MAX_SEQ_LEN, sampler_type='tiger', is_extended=True ) - with open(os.path.join(IREC_PATH, 'results/rqvae_beauty_best_clusters_colisionless.json'), 'r') as f: + with open(SEMANTIC_MAPPING_PATH, 'r') as f: mappings = json.load(f) - + train_dataset, valid_dataset, eval_dataset = data.get_datasets() train_dataloader = DataLoader( @@ -219,7 +231,7 @@ def main(): 
.map(ToMasked('item.semantic', is_right_aligned=True)) \ .map(ToMasked('labels.semantic', is_right_aligned=True)) \ .map(TigerProcessing()) - + valid_dataloader = DataLoader( dataset=valid_dataset, batch_size=VALID_BATCH_SIZE, @@ -251,17 +263,18 @@ def main(): train_batches = [] for train_batch in train_dataloader: train_batches.append(train_batch) - save_batches_to_arrow(train_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/')) - + save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR) + valid_batches = [] for valid_batch in valid_dataloader: valid_batches.append(valid_batch) - save_batches_to_arrow(valid_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/')) - + save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR) + eval_batches = [] for eval_batch in eval_dataloader: eval_batches.append(eval_batch) - save_batches_to_arrow(eval_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/')) + save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR) + if __name__ == '__main__': diff --git a/scripts/tiger/varka_timestamp_based.py b/scripts/tiger/varka_timestamp_based.py new file mode 100644 index 0000000..11343ea --- /dev/null +++ b/scripts/tiger/varka_timestamp_based.py @@ -0,0 +1,287 @@ +from collections import defaultdict +import json +import murmurhash +import numpy as np +import os +from pathlib import Path + +import pyarrow as pa +import pyarrow.feather as feather + +import torch + +from irec.data.transforms import Collate, Transform +from irec.data.dataloader import DataLoader + +from data import Dataset + + + +# ПУТИ + +IREC_PATH = '../../' +INTERACTIONS_TRAIN_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/splits/exp_data/exp_4_inter_tiger_train.json') +INTERACTIONS_VALID_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/splits/exp_data/valid_skip_set.json') +INTERACTIONS_TEST_PATH = os.path.join(IREC_PATH, 'sigir/Beauty_new/splits/exp_data/test_set.json') + +SEMANTIC_MAPPING_PATH = os.path.join(IREC_PATH, 'results_sigir/4-1_plum_rqvae_beauty_ws_2_clusters_colisionless.json') +TRAIN_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_train_batches/') +VALID_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_valid_batches/') +EVAL_BATCHES_DIR = os.path.join(IREC_PATH, 'data/Beauty/tiger_4-1_eval_batches/') + + +# ОСТАЛЬНОЕ + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + + +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +NUM_USER_HASH = 2000 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 4 + +UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10 # 10 for utilities +PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1, +EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2, +DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3, + + + +class TigerProcessing(Transform): + def __call__(self, batch): + input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask'] + batch_size = attention_mask.shape[0] + + input_semantic_ids[~attention_mask] = PAD_TOKEN_ID # TODO ??? 
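+ # Padded positions are overwritten with PAD_TOKEN_ID so the zero-filled slots produced
+ # by ToMasked are not mistaken for item token 0. The hashed user id appended below is
+ # offset by NUM_CODEBOOKS * CODEBOOK_SIZE, keeping user tokens disjoint from the item
+ # semantic-ID range and below the PAD / EOS / decoder-start utility tokens.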
+ + input_semantic_ids = np.concatenate([ + input_semantic_ids, + NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None] + ], axis=-1) + + attention_mask = np.concatenate([ + attention_mask, + np.ones((batch_size, 1), dtype=attention_mask.dtype) + ], axis=-1) + + batch['input.data'] = input_semantic_ids + batch['input.mask'] = attention_mask + + target_semantic_ids = batch['labels.semantic.padded'] + target_semantic_ids = np.concatenate([ + np.ones( + (batch_size, 1), + dtype=np.int64, + ) * DECODER_START_TOKEN_ID, + target_semantic_ids + ], axis=-1) + + batch['output.data'] = target_semantic_ids + + return batch + + +class ToMasked(Transform): + def __init__(self, prefix, is_right_aligned=False): + self._prefix = prefix + self._is_right_aligned = is_right_aligned + + def __call__(self, batch): + data = batch[f'{self._prefix}.ids'] + lengths = batch[f'{self._prefix}.length'] + + batch_size = lengths.shape[0] + max_sequence_length = int(lengths.max()) + + if len(data.shape) == 1: # only indices + padded_tensor = np.zeros( + (batch_size, max_sequence_length), + dtype=data.dtype + ) # (batch_size, max_seq_len) + else: + assert len(data.shape) == 2 # embeddings + padded_tensor = np.zeros( + (batch_size, max_sequence_length, data.shape[-1]), + dtype=data.dtype + ) # (batch_size, max_seq_len, emb_dim) + + mask = np.arange(max_sequence_length)[None] < lengths[:, None] + + if self._is_right_aligned: + mask = np.flip(mask, axis=-1) + + padded_tensor[mask] = data + + batch[f'{self._prefix}.padded'] = padded_tensor + batch[f'{self._prefix}.mask'] = mask + + return batch + + +class SemanticIdsMapper(Transform): + def __init__(self, mapping, names=[]): + super().__init__() + self._mapping = mapping + self._names = names + + data = [] + for i in range(len(mapping)): + data.append(mapping[str(i)]) + self._mapping_tensor = torch.tensor(data, dtype=torch.long) + self._semantic_length = self._mapping_tensor.shape[-1] + + def __call__(self, batch): + for name in self._names: + if f'{name}.ids' in batch: + ids = batch[f'{name}.ids'] + lengths = batch[f'{name}.length'] + assert ids.min() >= 0 + assert ids.max() < self._mapping_tensor.shape[0] + batch[f'{name}.semantic.ids'] = self._mapping_tensor[ids].flatten().numpy() + batch[f'{name}.semantic.length'] = lengths * self._semantic_length + + return batch + + +class UserHashing(Transform): + def __init__(self, hash_size): + super().__init__() + self._hash_size = hash_size + + def __call__(self, batch): + batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64) + return batch + + +def save_batches_to_arrow(batches, output_dir): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + for batch_idx, batch in enumerate(batches): + length_groups = defaultdict(dict) + metadata_groups = defaultdict(dict) + + for key, value in batch.items(): + length = len(value) + + metadata_groups[length][f'{key}_shape'] = str(value.shape) + metadata_groups[length][f'{key}_dtype'] = str(value.dtype) + + if value.ndim == 1: + # 1D массив - сохраняем как есть + length_groups[length][key] = value + elif value.ndim == 2: + # 2D массив - используем list of lists + length_groups[length][key] = value.tolist() + else: + # >2D массив - flatten и сохраняем shape + length_groups[length][key] = value.flatten() + + for length, fields in length_groups.items(): + arrow_dict = {} + for k, v in fields.items(): + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list): + # List of lists 
(2D) + arrow_dict[k] = pa.array(v) + else: + arrow_dict[k] = pa.array(v) + + table = pa.table(arrow_dict) + if length in metadata_groups: + table = table.replace_schema_metadata(metadata_groups[length]) + + feather.write_feather( + table, + output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + compression='lz4' + ) + + # arrow_dict = {k: pa.array(v) for k, v in fields.items()} + # table = pa.table(arrow_dict) + + # feather.write_feather( + # table, + # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + # compression='lz4' + # ) + + +def main(): + data = Dataset.create_timestamp_based( + train_json_path=INTERACTIONS_TRAIN_PATH, + validation_json_path=INTERACTIONS_VALID_PATH, + test_json_path=INTERACTIONS_TEST_PATH, + max_sequence_length=MAX_SEQ_LEN, + sampler_type='tiger', + min_sample_len=2, + is_extended=True + ) + + with open(SEMANTIC_MAPPING_PATH, 'r') as f: + mappings = json.load(f) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(TigerProcessing()) + + valid_dataloader = DataLoader( + dataset=valid_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + eval_dataloader = DataLoader( + dataset=eval_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + train_batches = [] + for train_batch in train_dataloader: + train_batches.append(train_batch) + save_batches_to_arrow(train_batches, TRAIN_BATCHES_DIR) + + valid_batches = [] + for valid_batch in valid_dataloader: + valid_batches.append(valid_batch) + save_batches_to_arrow(valid_batches, VALID_BATCHES_DIR) + + eval_batches = [] + for eval_batch in eval_dataloader: + eval_batches.append(eval_batch) + save_batches_to_arrow(eval_batches, EVAL_BATCHES_DIR) + + + +if __name__ == '__main__': + main() diff --git a/sigir/DatasetProcessing.ipynb b/sigir/DatasetProcessing.ipynb new file mode 100644 index 0000000..09b8d21 --- /dev/null +++ b/sigir/DatasetProcessing.ipynb @@ -0,0 +1,727 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "3bdb292f", + "metadata": {}, + "outputs": [], + "source": [ + "from collections import defaultdict\n", + "\n", + "import json\n", + "import numpy as np\n", + "import pandas as pd\n", + "import pickle\n", + "import polars as pl\n", + "\n", + "from transformers import LlamaModel, LlamaTokenizer\n", + "\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from tqdm import tqdm as tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "66d9b312", + 
"metadata": {}, + "outputs": [], + "source": [ + "interactions_dataset_path = '../data/Beauty/Beauty_5.json'\n", + "metadata_path = '../data/Beauty/metadata.json'\n", + "\n", + "interactions_output_json_path = '../data/Beauty_new/inter_new.json'\n", + "interactions_output_parquet_path = '../data/Beauty_new/inter_new.parquet'\n", + "embeddings_output_path = '../data/Beauty_new/content_embeddings.pkl'\n", + "item_ids_mapping_output_path = '../data/Beauty_new/item_ids_mapping.json'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6ed4dffb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of events: 198502\n" + ] + } + ], + "source": [ + "df = defaultdict(list)\n", + "\n", + "with open(interactions_dataset_path, 'r') as f:\n", + " for line in f.readlines():\n", + " review = json.loads(line)\n", + " df['user_id'].append(review['reviewerID'])\n", + " df['item_id'].append(review['asin'])\n", + " df['timestamp'].append(review['unixReviewTime'])\n", + "\n", + "print(f'Number of events: {len(df[\"user_id\"])}')\n", + "\n", + "df = pl.from_dict(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c26746c4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 3)
" + ], + "text/plain": [ + "shape: (5, 3)\n", + "┌────────────────┬────────────┬────────────┐\n", + "│ user_id ┆ item_id ┆ timestamp │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ str ┆ str ┆ i64 │\n", + "╞════════════════╪════════════╪════════════╡\n", + "│ A1YJEY40YUW4SE ┆ 7806397051 ┆ 1391040000 │\n", + "│ A60XNB876KYML ┆ 7806397051 ┆ 1397779200 │\n", + "│ A3G6XNM240RMWA ┆ 7806397051 ┆ 1378425600 │\n", + "│ A1PQFP6SAJ6D80 ┆ 7806397051 ┆ 1386460800 │\n", + "│ A38FVHZTNQ271F ┆ 7806397051 ┆ 1382140800 │\n", + "└────────────────┴────────────┴────────────┘" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "adcf5713", + "metadata": {}, + "outputs": [], + "source": [ + "filtered_df = df.clone()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c0bbf9ba", + "metadata": {}, + "outputs": [], + "source": [ + "# Processing dataset to get core-5 state in case full dataset is provided\n", + "is_changed = True\n", + "threshold = 5\n", + "good_users = set()\n", + "good_items = set()\n", + "\n", + "while is_changed:\n", + " user_counts = filtered_df.group_by('user_id').agg(\n", + " pl.len().alias('user_count'),\n", + " )\n", + " item_counts = filtered_df.group_by('item_id').agg(\n", + " pl.len().alias('item_count'),\n", + " )\n", + "\n", + " good_users = user_counts.filter(pl.col('user_count') >= threshold).select(\n", + " 'user_id',\n", + " )\n", + " good_items = item_counts.filter(pl.col('item_count') >= threshold).select(\n", + " 'item_id',\n", + " )\n", + "\n", + " old_size = len(filtered_df)\n", + "\n", + " new_df = filtered_df.join(good_users, on='user_id', how='inner')\n", + " new_df = new_df.join(good_items, on='item_id', how='inner')\n", + "\n", + " new_size = len(new_df)\n", + "\n", + " filtered_df = new_df\n", + " is_changed = old_size != new_size\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "218a9348", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 3)
" + ], + "text/plain": [ + "shape: (5, 3)\n", + "┌─────────┬─────────┬────────────┐\n", + "│ user_id ┆ item_id ┆ timestamp │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ i64 ┆ i64 ┆ i64 │\n", + "╞═════════╪═════════╪════════════╡\n", + "│ 0 ┆ 0 ┆ 1391040000 │\n", + "│ 1 ┆ 0 ┆ 1397779200 │\n", + "│ 2 ┆ 0 ┆ 1378425600 │\n", + "│ 3 ┆ 0 ┆ 1386460800 │\n", + "│ 4 ┆ 0 ┆ 1382140800 │\n", + "└─────────┴─────────┴────────────┘" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "unique_values = filtered_df[\"user_id\"].unique(maintain_order=True).to_list()\n", + "user_ids_mapping = {value: i for i, value in enumerate(unique_values)}\n", + "\n", + "filtered_df = filtered_df.with_columns(\n", + " pl.col(\"user_id\").replace_strict(user_ids_mapping)\n", + ")\n", + "\n", + "unique_values = filtered_df[\"item_id\"].unique(maintain_order=True).to_list()\n", + "item_ids_mapping = {value: i for i, value in enumerate(unique_values)}\n", + "\n", + "filtered_df = filtered_df.with_columns(\n", + " pl.col(\"item_id\").replace_strict(item_ids_mapping)\n", + ")\n", + "\n", + "filtered_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "34604fe6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 2)
" + ], + "text/plain": [ + "shape: (5, 2)\n", + "┌─────────────┬─────────────┐\n", + "│ old_item_id ┆ new_item_id │\n", + "│ --- ┆ --- │\n", + "│ str ┆ i64 │\n", + "╞═════════════╪═════════════╡\n", + "│ 7806397051 ┆ 0 │\n", + "│ 9759091062 ┆ 1 │\n", + "│ 9788072216 ┆ 2 │\n", + "│ 9790790961 ┆ 3 │\n", + "│ 9790794231 ┆ 4 │\n", + "└─────────────┴─────────────┘" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "item_ids_mapping_df = pl.from_dict({\n", + " 'old_item_id': list(item_ids_mapping.keys()),\n", + " 'new_item_id': list(item_ids_mapping.values())\n", + "})\n", + "item_ids_mapping_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "99b54db807b9495c", + "metadata": {}, + "outputs": [], + "source": [ + "with open(item_ids_mapping_output_path, 'w') as f:\n", + " json.dump({str(k): v for k, v in item_ids_mapping.items()}, f)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "6017e65c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 3)
" + ], + "text/plain": [ + "shape: (5, 3)\n", + "┌─────────┬─────────┬────────────┐\n", + "│ user_id ┆ item_id ┆ timestamp │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ i64 ┆ i64 ┆ i64 │\n", + "╞═════════╪═════════╪════════════╡\n", + "│ 0 ┆ 0 ┆ 1391040000 │\n", + "│ 1 ┆ 0 ┆ 1397779200 │\n", + "│ 2 ┆ 0 ┆ 1378425600 │\n", + "│ 3 ┆ 0 ┆ 1386460800 │\n", + "│ 4 ┆ 0 ┆ 1382140800 │\n", + "└─────────┴─────────┴────────────┘" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "filtered_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9efd1983", + "metadata": {}, + "outputs": [], + "source": [ + "filtered_df = filtered_df.sort([\"user_id\", \"timestamp\"])\n", + "\n", + "grouped_filtered_df = filtered_df.group_by(\"user_id\", maintain_order=True).agg(pl.all())" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "fd51c525", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 2)
" + ], + "text/plain": [ + "shape: (5, 2)\n", + "┌─────────────┬─────────────┐\n", + "│ old_item_id ┆ new_item_id │\n", + "│ --- ┆ --- │\n", + "│ str ┆ i64 │\n", + "╞═════════════╪═════════════╡\n", + "│ 7806397051 ┆ 0 │\n", + "│ 9759091062 ┆ 1 │\n", + "│ 9788072216 ┆ 2 │\n", + "│ 9790790961 ┆ 3 │\n", + "│ 9790794231 ┆ 4 │\n", + "└─────────────┴─────────────┘" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "item_ids_mapping_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "8b0821da", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 3)
" + ], + "text/plain": [ + "shape: (5, 3)\n", + "┌─────────┬─────────────────────┬─────────────────────────────────┐\n", + "│ user_id ┆ item_id ┆ timestamp │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ i64 ┆ list[i64] ┆ list[i64] │\n", + "╞═════════╪═════════════════════╪═════════════════════════════════╡\n", + "│ 0 ┆ [6845, 7872, … 0] ┆ [1318896000, 1318896000, … 139… │\n", + "│ 1 ┆ [815, 10405, … 232] ┆ [1392422400, 1396224000, … 139… │\n", + "│ 2 ┆ [6049, 0, … 6608] ┆ [1378425600, 1378425600, … 140… │\n", + "│ 3 ┆ [5521, 5160, … 0] ┆ [1379116800, 1380931200, … 138… │\n", + "│ 4 ┆ [0, 10469, … 11389] ┆ [1382140800, 1383523200, … 138… │\n", + "└─────────┴─────────────────────┴─────────────────────────────────┘" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "grouped_filtered_df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "dc222d59", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Users count: 22363\n", + "Items count: 12101\n", + "Actions count: 198502\n", + "Avg user history len: 8.876358270357287\n" + ] + } + ], + "source": [ + "print('Users count:', filtered_df.select('user_id').unique().shape[0])\n", + "print('Items count:', filtered_df.select('item_id').unique().shape[0])\n", + "print('Actions count:', filtered_df.shape[0])\n", + "print('Avg user history len:', np.mean(list(map(lambda x: x[0], grouped_filtered_df.select(pl.col('item_id').list.len()).rows()))))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "a272855d-84b2-4414-ba9f-62647e1151cf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "shape: (5, 3)\n", + "┌─────┬─────────────────────┬─────────────────────────────────┐\n", + "│ uid ┆ item_ids ┆ timestamps │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ i64 ┆ list[i64] ┆ list[i64] │\n", + "╞═════╪═════════════════════╪═════════════════════════════════╡\n", + "│ 0 ┆ [6845, 7872, … 0] ┆ [1318896000, 1318896000, … 139… │\n", + "│ 1 ┆ [815, 10405, … 232] ┆ [1392422400, 1396224000, … 139… │\n", + "│ 2 ┆ [6049, 0, … 6608] ┆ [1378425600, 1378425600, … 140… │\n", + "│ 3 ┆ [5521, 5160, … 0] ┆ [1379116800, 1380931200, … 138… │\n", + "│ 4 ┆ [0, 10469, … 11389] ┆ [1382140800, 1383523200, … 138… │\n", + "└─────┴─────────────────────┴─────────────────────────────────┘\n" + ] + } + ], + "source": [ + "inter_new = grouped_filtered_df.select([\n", + " pl.col(\"user_id\").alias(\"uid\"),\n", + " pl.col(\"item_id\").alias(\"item_ids\"),\n", + " pl.col(\"timestamp\").alias(\"timestamps\")\n", + "])\n", + "\n", + "print(inter_new.head())" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "de5a853a-8ee2-42dd-a71a-6cc6f90d526c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Файл успешно сохранен: ../data/Beauty_new/inter_new.parquet\n" + ] + } + ], + "source": [ + "output_path_parquet = interactions_output_parquet_path\n", + "inter_new.write_parquet(output_path_parquet)\n", + "\n", + "print(f\"Файл успешно сохранен: {output_path_parquet}\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "d07a2e91", + "metadata": {}, + "outputs": [], + "source": [ + "json_data = {}\n", + "for user_id, item_ids, _ in grouped_filtered_df.iter_rows():\n", + " json_data[user_id] = item_ids\n", + "\n", + "with open(interactions_output_json_path, 'w') as f:\n", + " json.dump(json_data, f, indent=2)" + ] + }, + { + "cell_type": 
"markdown", + "id": "237523fa", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## Content embedding creation" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "6361c7a5", + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[19], line 5\u001B[0m, in \u001B[0;36mgetDF\u001B[0;34m(path)\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mopen\u001B[39m(path, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mr\u001B[39m\u001B[38;5;124m'\u001B[39m) \u001B[38;5;28;01mas\u001B[39;00m f:\n\u001B[0;32m----> 5\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m line \u001B[38;5;129;01min\u001B[39;00m \u001B[43mf\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreadlines\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m:\n\u001B[1;32m 6\u001B[0m df[i] \u001B[38;5;241m=\u001B[39m \u001B[38;5;28meval\u001B[39m(line)\n", + "File \u001B[0;32m/usr/lib/python3.10/codecs.py:319\u001B[0m, in \u001B[0;36mBufferedIncrementalDecoder.decode\u001B[0;34m(self, input, final)\u001B[0m\n\u001B[1;32m 317\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mNotImplementedError\u001B[39;00m\n\u001B[0;32m--> 319\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;21mdecode\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;28minput\u001B[39m, final\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mFalse\u001B[39;00m):\n\u001B[1;32m 320\u001B[0m \u001B[38;5;66;03m# decode input (taking the buffer into account)\u001B[39;00m\n\u001B[1;32m 321\u001B[0m data \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mbuffer \u001B[38;5;241m+\u001B[39m \u001B[38;5;28minput\u001B[39m\n", + "\u001B[0;31mKeyboardInterrupt\u001B[0m: ", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[19], line 11\u001B[0m\n\u001B[1;32m 7\u001B[0m i \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m\n\u001B[1;32m 9\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m pd\u001B[38;5;241m.\u001B[39mDataFrame\u001B[38;5;241m.\u001B[39mfrom_dict(df, orient\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mindex\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m---> 11\u001B[0m df \u001B[38;5;241m=\u001B[39m \u001B[43mgetDF\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmetadata_path\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 12\u001B[0m df\u001B[38;5;241m.\u001B[39mhead()\n", + "Cell \u001B[0;32mIn[19], line 5\u001B[0m, in \u001B[0;36mgetDF\u001B[0;34m(path)\u001B[0m\n\u001B[1;32m 3\u001B[0m df \u001B[38;5;241m=\u001B[39m {}\n\u001B[1;32m 4\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mopen\u001B[39m(path, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mr\u001B[39m\u001B[38;5;124m'\u001B[39m) \u001B[38;5;28;01mas\u001B[39;00m f:\n\u001B[0;32m----> 5\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m line \u001B[38;5;129;01min\u001B[39;00m \u001B[43mf\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mreadlines\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m:\n\u001B[1;32m 6\u001B[0m df[i] \u001B[38;5;241m=\u001B[39m \u001B[38;5;28meval\u001B[39m(line)\n\u001B[1;32m 
7\u001B[0m i \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1\u001B[39m\n", + "\u001B[0;31mKeyboardInterrupt\u001B[0m: " + ] + } + ], + "source": [ + "def getDF(path):\n", + " i = 0\n", + " df = {}\n", + " with open(path, 'r') as f:\n", + " for line in f.readlines():\n", + " df[i] = eval(line)\n", + " i += 1\n", + "\n", + " return pd.DataFrame.from_dict(df, orient=\"index\")\n", + "\n", + "df = getDF(metadata_path)\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "971fa89c", + "metadata": {}, + "outputs": [], + "source": [ + "def preprocess(row: pd.Series):\n", + " row = row.fillna(\"None\")\n", + " return f\"Title: {row['title']}. Categories: {', '.join(row['categories'][0])}. Description: {row['description']}.\"\n", + "\n", + "\n", + "def get_data(metadata_df, item_ids_mapping_df):\n", + " filtered_df = metadata_df.join(\n", + " item_ids_mapping_df, \n", + " left_on=\"asin\", \n", + " right_on='old_item_id', \n", + " how=\"inner\"\n", + " ).select(pl.col('new_item_id'), pl.col('title'), pl.col('description'), pl.col('categories'))\n", + "\n", + " filtered_df = filtered_df.to_pandas()\n", + " filtered_df[\"combined_text\"] = filtered_df.apply(preprocess, axis=1)\n", + "\n", + " return filtered_df\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b0dd5d5", + "metadata": {}, + "outputs": [], + "source": [ + "data = get_data(pl.from_pandas(df), item_ids_mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12e622ff", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda:6')\n", + "\n", + "model_name = \"huggyllama/llama-7b\"\n", + "tokenizer = LlamaTokenizer.from_pretrained(model_name)\n", + "\n", + "if tokenizer.pad_token is None:\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "model = LlamaModel.from_pretrained(model_name)\n", + "model = model.to(device)\n", + "model = model.eval()\n", + "\n", + "\n", + "class MyDataset:\n", + " def __init__(self, data):\n", + " self._data = list(zip(data.to_dict()['new_item_id'].values(), data.to_dict()['combined_text'].values()))\n", + "\n", + " def __len__(self):\n", + " return len(self._data)\n", + "\n", + " def __getitem__(self, idx):\n", + " text = self._data[idx][1]\n", + " inputs = tokenizer(text, return_tensors=\"pt\", max_length=1024, truncation=True, padding=\"max_length\")\n", + " return {\n", + " 'item_id': self._data[idx][0],\n", + " 'input_ids': inputs['input_ids'][0],\n", + " 'attention_mask': inputs['attention_mask'][0]\n", + " }\n", + " \n", + "\n", + "dataset = MyDataset(data)\n", + "loader = DataLoader(dataset, batch_size=8, drop_last=False, shuffle=False, num_workers=10)\n", + "\n", + "\n", + "new_df = {\n", + " 'item_id': [],\n", + " 'embedding': []\n", + "}\n", + "\n", + "for batch in tqdm(loader):\n", + " with torch.inference_mode():\n", + " outputs = model(\n", + " input_ids=batch[\"input_ids\"].to(device), \n", + " attention_mask=batch[\"attention_mask\"].to(device)\n", + " )\n", + " embeddings = outputs.last_hidden_state\n", + " \n", + " embeddings = outputs.last_hidden_state # (bs, sl, ed)\n", + " embeddings[(~batch[\"attention_mask\"].bool())] = 0. 
# (bs, sl, ed)\n", + "\n", + " new_df['item_id'] += batch['item_id'].tolist()\n", + " new_df['embedding'] += embeddings.mean(dim=1).tolist() # (bs, ed)\n", + "\n", + "\n", + "with open(embeddings_output_path, 'wb') as f:\n", + " pickle.dump(new_df, f)\n" + ] + }, + { + "cell_type": "markdown", + "id": "a6fffc4a-85f1-424e-b460-29e526df3317", + "metadata": {}, + "source": [ + "# Test" + ] + }, + { + "cell_type": "code", + "id": "1f922431-e3c1-4587-86d1-04a58eb8ffee", + "metadata": { + "ExecuteTime": { + "end_time": "2025-11-29T16:36:01.750784Z", + "start_time": "2025-11-29T16:36:01.638604Z" + } + }, + "source": [ + "# df = pl.read_parquet(interactions_output_parquet_path)\n", + "# \n", + "# df_timestamps = df.select(\n", + "# pl.col(\"timestamps\").explode()\n", + "# )\n", + "# min_time = df_timestamps.select(pl.col(\"timestamps\").min()).item()\n", + "# max_time = df_timestamps.select(pl.col(\"timestamps\").max()).item()\n", + "# \n", + "# cutoffs = [\n", + "# min_time + (max_time - min_time) * 0.7, # 70%\n", + "# min_time + (max_time - min_time) * 0.8, # 80%\n", + "# min_time + (max_time - min_time) * 0.9, # 90%\n", + "# ]\n" + ], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'pl' is not defined", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mNameError\u001B[0m Traceback (most recent call last)", + "Cell \u001B[0;32mIn[1], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m df \u001B[38;5;241m=\u001B[39m \u001B[43mpl\u001B[49m\u001B[38;5;241m.\u001B[39mread_parquet(interactions_output_parquet_path)\n\u001B[1;32m 3\u001B[0m df_timestamps \u001B[38;5;241m=\u001B[39m df\u001B[38;5;241m.\u001B[39mselect(\n\u001B[1;32m 4\u001B[0m pl\u001B[38;5;241m.\u001B[39mcol(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtimestamps\u001B[39m\u001B[38;5;124m\"\u001B[39m)\u001B[38;5;241m.\u001B[39mexplode()\n\u001B[1;32m 5\u001B[0m )\n\u001B[1;32m 6\u001B[0m min_time \u001B[38;5;241m=\u001B[39m df_timestamps\u001B[38;5;241m.\u001B[39mselect(pl\u001B[38;5;241m.\u001B[39mcol(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mtimestamps\u001B[39m\u001B[38;5;124m\"\u001B[39m)\u001B[38;5;241m.\u001B[39mmin())\u001B[38;5;241m.\u001B[39mitem()\n", + "\u001B[0;31mNameError\u001B[0m: name 'pl' is not defined" + ] + } + ], + "execution_count": 1 + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/sigir/exps_data.ipynb b/sigir/exps_data.ipynb new file mode 100644 index 0000000..bf170a5 --- /dev/null +++ b/sigir/exps_data.ipynb @@ -0,0 +1,377 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "e2462a97-6705-44e1-a232-4dd78a5dfc85", + "metadata": {}, + "outputs": [], + "source": [ + "import polars as pl\n", + "import json\n", + "from typing import List, Dict" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fd38624d-5796-4aa5-929f-7e82c5544f6c", + "metadata": {}, + "outputs": [], + "source": [ + "interactions_output_parquet_path = '../data/Beauty_new/inter_new.parquet'\n", + "df = pl.read_parquet(interactions_output_parquet_path)" + ] + }, + { + "cell_type": "code", + 
"execution_count": 3, + "id": "ee127317-66b8-4f22-9109-94bcb8b1f1ae", + "metadata": {}, + "outputs": [], + "source": [ + "def split_session_by_timestamps(\n", + " df: pl.DataFrame,\n", + " time_cutoffs: List[int],\n", + " output_dir: str = None,\n", + ") -> List[Dict[int, List[int]]]:\n", + " \n", + " result_dicts = []\n", + " \n", + " def extract_interval(df_source, start, end=None):\n", + " q = df_source.lazy()\n", + " q = q.explode([\"item_ids\", \"timestamps\"])\n", + " \n", + " if end is not None:\n", + " q = q.filter(\n", + " (pl.col(\"timestamps\") >= start) & \n", + " (pl.col(\"timestamps\") < end)\n", + " )\n", + " else:\n", + " q = q.filter(\n", + " pl.col(\"timestamps\") >= start\n", + " )\n", + " \n", + " q = q.group_by(\"uid\").agg([\n", + " pl.col(\"item_ids\").alias(\"item_ids\")\n", + " ]).sort(\"uid\")\n", + " \n", + " return q.collect()\n", + " \n", + " intervals = []\n", + " current_start = 0\n", + " for cutoff in time_cutoffs:\n", + " intervals.append((current_start, cutoff))\n", + " current_start = cutoff\n", + " # от последнего cutoff до бесконечности\n", + " intervals.append((current_start, None))\n", + " \n", + " for start, end in intervals:\n", + " subset = extract_interval(df, start, end)\n", + " \n", + " json_dict = {}\n", + " for user_id, item_ids in subset.iter_rows():\n", + " json_dict[user_id] = item_ids\n", + " \n", + " result_dicts.append(json_dict)\n", + " \n", + " if output_dir:\n", + " if end is not None:\n", + " filename = f\"inter_new_[{start}_{end}).json\"\n", + " else:\n", + " filename = f\"inter_new_[{start}_inf).json\"\n", + " \n", + " filepath = f\"{output_dir}/{filename}\"\n", + " with open(filepath, 'w') as f:\n", + " json.dump(json_dict, f, indent=2)\n", + " \n", + " print(f\"Сохранено: {filepath}\")\n", + " \n", + " return result_dicts" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "efc8b582-dd8a-4299-9c49-de906251de8a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cutoffs: [1402444800, 1403654400, 1404864000]\n", + "✓ Сохранено: ../sigir/Beauty_new/splits/raw/inter_new_[0_1402444800).json\n", + "✓ Сохранено: ../sigir/Beauty_new/splits/raw/inter_new_[1402444800_1403654400).json\n", + "✓ Сохранено: ../sigir/Beauty_new/splits/raw/inter_new_[1403654400_1404864000).json\n", + "✓ Сохранено: ../sigir/Beauty_new/splits/raw/inter_new_[1404864000_inf).json\n", + "Part 0 [Base]: 22029 users\n", + "Part 1 [Week -6]: 1854 users\n", + "Part 2 [Week -4]: 1945 users\n", + "Part 3 [Week -2]: 1381 users\n" + ] + } + ], + "source": [ + "global_max_time = df.select(\n", + " pl.col(\"timestamps\").explode().max()\n", + ").item()\n", + "\n", + "days_val = 14\n", + "window_sec = days_val * 24 * 3600 \n", + "\n", + "cutoff_test_start = global_max_time - window_sec # T - 2w\n", + "cutoff_val_start = global_max_time - 2 * window_sec # T - 4w\n", + "cutoff_gap_start = global_max_time - 3 * window_sec # T - 6w\n", + "\n", + "cutoffs = [\n", + " int(cutoff_gap_start),\n", + " int(cutoff_val_start),\n", + " int(cutoff_test_start)\n", + "]\n", + "\n", + "print(f\"Cutoffs: {cutoffs}\")\n", + "\n", + "split_files = split_session_by_timestamps(\n", + " df, \n", + " cutoffs, \n", + " output_dir=\"../sigir/Beauty_new/splits/raw\"\n", + ")\n", + "\n", + "names = [\"Base\", \"Week -6\", \"Week -4\", \"Week -2\"]\n", + "for i, d in enumerate(split_files):\n", + " print(f\"Part {i} [{names[i]}]: {len(d)} users\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": 
"d5ba172e-b430-40a3-a4fa-64366d02a015", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "--- Creating Experiments ---\n" + ] + } + ], + "source": [ + "def merge_and_save(parts_to_merge, dirr, output_name):\n", + " merged = {}\n", + " print(f\"Merging {len(parts_to_merge)} files into {output_name}...\")\n", + " \n", + " for part in parts_to_merge:\n", + " # with open(fp, 'r') as f:\n", + " # part = json.load(f)\n", + " for uid, items in part.items():\n", + " if uid not in merged:\n", + " merged[uid] = []\n", + " merged[uid].extend(items)\n", + " \n", + " out_path = f\"{dirr}/{output_name}\"\n", + " with open(out_path, 'w') as f:\n", + " json.dump(merged, f)\n", + " print(f\"Done: {out_path} (Users: {len(merged)})\")\n", + "\n", + "p0, p1, p2, p3 = split_files[0], split_files[1], split_files[2], split_files[3]" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "d116b7e0-9bf9-4104-86a0-69788a70cc14", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Merging 2 files into exp_4_inter_tiger_train.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4_inter_tiger_train.json (Users: 22129)\n", + "Merging 2 files into exp_4.1_inter_semantics_train.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4.1_inter_semantics_train.json (Users: 22129)\n", + "Merging 1 files into exp_4.2_inter_semantics_train_short.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4.2_inter_semantics_train_short.json (Users: 22029)\n", + "Merging 3 files into exp_4.3_inter_semantics_train_leak.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/exp_4.3_inter_semantics_train_leak.json (Users: 22265)\n", + "Merging 1 files into test_set.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/test_set.json (Users: 1381)\n", + "Merging 1 files into valid_skip_set.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/valid_skip_set.json (Users: 1945)\n", + "\n", + "All done!\n" + ] + } + ], + "source": [ + "EXP_DIR = \"../sigir/Beauty_new/splits/exp_data\"\n", + "\n", + "# Tiger: P0+P1\n", + "merge_and_save([p0, p1], EXP_DIR, \"exp_4_inter_tiger_train.json\")\n", + "\n", + "# 1. Exp 4.1\n", + "# Semantics: P0+P1 (Всё кроме валида и теста)\n", + "merge_and_save([p0, p1], EXP_DIR, \"exp_4.1_inter_semantics_train.json\")\n", + "\n", + "# 2. Exp 4.2\n", + "# Semantics: P0 (Короче на неделю, без P1)\n", + "merge_and_save([p0], EXP_DIR, \"exp_4.2_inter_semantics_train_short.json\")\n", + "\n", + "# 3. Exp 4.3\n", + "# Semantics: P0+P1+P2 (Видит валидацию)\n", + "merge_and_save([p0, p1, p2], EXP_DIR, \"exp_4.3_inter_semantics_train_leak.json\")\n", + "\n", + "# 4. Test Set (тест всех моделей)\n", + "merge_and_save([p3], EXP_DIR, \"test_set.json\")\n", + "\n", + "# 4. 
Valid Set (пропуск, имитируется разница трейна и теста чтобы потом дообучать)\n", + "merge_and_save([p2], EXP_DIR, \"valid_skip_set.json\")\n", + "\n", + "print(\"\\nAll done!\")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "9ae1d1e5-567d-471a-8f83-4039ecacc8d2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Merging 4 files into all_set.json...\n", + "✓ Done: ../sigir/Beauty_new/splits/exp_data/all_set.json (Users: 22363)\n" + ] + } + ], + "source": [ + "merge_and_save([p0, p1, p2, p3], EXP_DIR, \"all_set.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "328de16c-f61d-45be-8a72-5f0eaef612e8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Проверка Train сетов (должны быть префиксами):\n", + "✅ [ПРЕФИКСЫ] Все 22129 массивов ОК. Полных совпадений: 19410\n", + "✅ [ПРЕФИКСЫ] Все 22029 массивов ОК. Полных совпадений: 18191\n", + "✅ [ПРЕФИКСЫ] Все 22265 массивов ОК. Полных совпадений: 20982\n", + "✅ [ПРЕФИКСЫ] Все 22129 массивов ОК. Полных совпадений: 19410\n", + "\n", + "Проверка Test сета (должен быть суффиксом):\n", + "✅ [СУФФИКСЫ] Все 1381 массивов ОК. Полных совпадений: 98\n", + "\n", + "(Контроль) Проверка Test сета как префикса (должна упасть):\n", + "❌ [ПРЕФИКСЫ] Найдено 1283 ошибок.\n" + ] + } + ], + "source": [ + "with open(\"../data/Beauty/inter_new.json\", 'r') as f:\n", + " old_inter_new = json.load(f)\n", + "\n", + "with open(\"../sigir/Beauty_new/splits/exp_data/exp_4.1_inter_semantics_train.json\", 'r') as ff:\n", + " first_sem = json.load(ff)\n", + " \n", + "with open(\"../sigir/Beauty_new/splits/exp_data/exp_4.2_inter_semantics_train_short.json\", 'r') as ff:\n", + " second_sem = json.load(ff)\n", + " \n", + "with open(\"../sigir/Beauty_new/splits/exp_data/exp_4.3_inter_semantics_train_leak.json\", 'r') as ff:\n", + " third_sem = json.load(ff)\n", + " \n", + "with open(\"../sigir/Beauty_new/splits/exp_data/exp_4_inter_tiger_train.json\", 'r') as ff:\n", + " tiger_sem = json.load(ff)\n", + "\n", + "with open(\"../sigir/Beauty_new/splits/exp_data/test_set.json\", 'r') as ff:\n", + " test_sem = json.load(ff)\n", + "\n", + "def check_prefix_match(full_data, subset_data, check_suffix=False):\n", + " \"\"\"\n", + " check_suffix=True включит режим проверки суффиксов (для теста).\n", + " \"\"\"\n", + " mismatch_count = 0\n", + " full_match_count = 0\n", + " \n", + " for user, sub_items in subset_data.items():\n", + " if user not in full_data:\n", + " print(f\"Юзер {user} не найден в исходном файле!\")\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " full_items = full_data[user]\n", + " \n", + " if not check_suffix:\n", + " if len(sub_items) > len(full_items):\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " if full_items[:len(sub_items)] == sub_items:\n", + " if len(full_items) == len(sub_items):\n", + " full_match_count += 1\n", + " else:\n", + " mismatch_count += 1\n", + " \n", + " else:\n", + " if len(sub_items) > len(full_items):\n", + " mismatch_count += 1\n", + " continue\n", + " \n", + " if full_items[-len(sub_items):] == sub_items:\n", + " if len(full_items) == len(sub_items):\n", + " full_match_count += 1\n", + " else:\n", + " mismatch_count += 1\n", + "\n", + " mode = \"СУФФИКСЫ\" if check_suffix else \"ПРЕФИКСЫ\"\n", + " \n", + " if mismatch_count == 0:\n", + " print(f\"[{mode}] Все {len(subset_data)} массивов ОК. 
Exact matches: {full_match_count}\")\n", + " else:\n", + " print(f\"[{mode}] Found {mismatch_count} mismatches.\")\n", + "\n", + "print(\"Checking train sets (each must be a prefix):\")\n", + "check_prefix_match(old_inter_new, first_sem)\n", + "check_prefix_match(old_inter_new, second_sem)\n", + "check_prefix_match(old_inter_new, third_sem)\n", + "check_prefix_match(old_inter_new, tiger_sem)\n", + "\n", + "print(\"\\nChecking the test set (must be a suffix):\")\n", + "check_prefix_match(old_inter_new, test_sem, check_suffix=True)\n", + "\n", + "print(\"\\n(Control) Checking the test set as a prefix (should fail):\")\n", + "check_prefix_match(old_inter_new, test_sem, check_suffix=False)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
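A minimal sanity-check sketch for the raw splits written by split_session_by_timestamps in the notebook above, assuming its directory layout (../sigir/Beauty_new/splits/raw and ../data/Beauty/inter_new.json); the timestamped filenames are copied from the recorded cell output and will differ on other data. It re-merges the four time-window parts per user and reports how many concatenated sequences fail to reproduce the user's full history, complementing the prefix/suffix checks done by check_prefix_match.

# Sanity check (sketch, paths are assumptions mirroring the notebook's output):
# re-merge the four raw time-window parts per user and compare against the
# full interaction file used by check_prefix_match.
import json

RAW_DIR = "../sigir/Beauty_new/splits/raw"
RAW_PARTS = [
    "inter_new_[0_1402444800).json",
    "inter_new_[1402444800_1403654400).json",
    "inter_new_[1403654400_1404864000).json",
    "inter_new_[1404864000_inf).json",
]

def load_json(path):
    with open(path, "r") as f:
        return json.load(f)

full = load_json("../data/Beauty/inter_new.json")

# Concatenate each user's items in chronological part order (parts are disjoint in time).
merged = {}
for name in RAW_PARTS:
    for uid, items in load_json(f"{RAW_DIR}/{name}").items():
        merged.setdefault(uid, []).extend(items)

# A user passes if the re-merged sequence equals the full recorded sequence.
bad = [uid for uid, items in merged.items() if full.get(uid) != items]
print(f"Re-merged {len(merged)} users; {len(bad)} full-sequence mismatches")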