diff --git a/.gitignore b/.gitignore index 5576121c..ae0a51ea 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,10 @@ .idea __pycache__ data/* -tensorboard_logs/* +*tensorboard_logs*/* +saved_logs/* +.venv +papers checkpoints/* +*.prof +uv.lock diff --git a/.python-version b/.python-version new file mode 100644 index 00000000..24ee5b1b --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.13 diff --git a/configs/train/letter.json b/configs/train/letter.json new file mode 100644 index 00000000..576106b7 --- /dev/null +++ b/configs/train/letter.json @@ -0,0 +1,186 @@ +{ + "experiment_name": "letter_data", + "best_metric": "validation/ndcg@20", + "train_epochs_num": 100, + "dataset": { + "type": "letter_full", + "path_to_data_dir": "../data", + "name": "Beauty_letter", + "max_sequence_length": 50, + "samplers": { + "type": "last_item_prediction", + "negative_sampler_type": "random" + }, + "beauty_inter_json": "../../LETTER/data/Beauty/Beauty.inter.json" + }, + "dataloader": { + "train": { + "type": "torch", + "batch_size": 256, + "batch_processor": { + "type": "letter", + "beauty_index_json": "../../LETTER/data/Beauty/Beauty.index.json", + "semantic_length": 4 + }, + "drop_last": true, + "shuffle": true + }, + "validation": { + "type": "torch", + "batch_size": 256, + "batch_processor": { + "type": "letter", + "beauty_index_json": "../../LETTER/data/Beauty/Beauty.index.json", + "semantic_length": 4 + }, + "drop_last": false, + "shuffle": false + } + }, + "model": { + "type": "tiger", + "rqvae_train_config_path": "../configs/train/rqvae_train_config.json", + "rqvae_checkpoint_path": "../checkpoints/rqvae_beauty_final_state.pth", + "embs_extractor_path": "../data/Beauty/rqvae/data_full.pt", + "sequence_prefix": "item", + "predictions_prefix": "logits", + "positive_prefix": "labels", + "labels_prefix": "labels", + "embedding_dim": 64, + "num_heads": 2, + "num_encoder_layers": 2, + "num_decoder_layers": 2, + "dim_feedforward": 256, + "dropout": 0.3, + "activation": "gelu", + "layer_norm_eps": 1e-9, + "initializer_range": 0.02 + }, + "optimizer": { + "type": "basic", + "optimizer": { + "type": "adam", + "lr": 0.001 + }, + "clip_grad_threshold": 5.0 + }, + "loss": { + "type": "composite", + "losses": [ + { + "type": "ce", + "predictions_prefix": "logits", + "labels_prefix": "semantic.labels", + "weight": 1.0, + "output_prefix": "semantic_loss" + }, + { + "type": "ce", + "predictions_prefix": "dedup.logits", + "labels_prefix": "dedup.labels", + "weight": 1.0, + "output_prefix": "dedup_loss" + } + ], + "output_prefix": "loss" + }, + "callback": { + "type": "composite", + "callbacks": [ + { + "type": "metric", + "on_step": 1, + "loss_prefix": "loss" + }, + { + "type": "validation", + "on_step": 1024, + "pred_prefix": "logits", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "ndcg", + "k": 20 + }, + "recall@5": { + "type": "recall", + "k": 5 + }, + "recall@10": { + "type": "recall", + "k": 10 + }, + "recall@20": { + "type": "recall", + "k": 20 + }, + "coverage@5": { + "type": "coverage", + "k": 5 + }, + "coverage@10": { + "type": "coverage", + "k": 10 + }, + "coverage@20": { + "type": "coverage", + "k": 20 + } + } + }, + { + "type": "eval", + "on_step": 2048, + "pred_prefix": "logits", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "ndcg", + "k": 20 + }, + "recall@5": { + 
"type": "recall", + "k": 5 + }, + "recall@10": { + "type": "recall", + "k": 10 + }, + "recall@20": { + "type": "recall", + "k": 20 + }, + "coverage@5": { + "type": "coverage", + "k": 5 + }, + "coverage@10": { + "type": "coverage", + "k": 10 + }, + "coverage@20": { + "type": "coverage", + "k": 20 + } + } + } + ] + } + } + \ No newline at end of file diff --git a/configs/train/rqvae_train_config.json b/configs/train/rqvae_train_config.json new file mode 100644 index 00000000..8b7d8bd9 --- /dev/null +++ b/configs/train/rqvae_train_config.json @@ -0,0 +1,65 @@ +{ + "experiment_name": "rqvae_beauty", + "train_epochs_num": 200, + "dataset": { + "type": "rqvae", + "path_to_data_dir": "../data", + "name": "Beauty", + "samplers": { + "type": "identity" + } + }, + "dataloader": { + "train": { + "type": "torch", + "batch_size": 256, + "batch_processor": { + "type": "embed" + }, + "drop_last": false, + "shuffle": true + }, + "validation": { + "type": "torch", + "batch_size": 256, + "batch_processor": { + "type": "embed" + }, + "drop_last": false, + "shuffle": false + } + }, + "model": { + "type": "rqvae", + "embedding_dim": 512, + "hidden_dim": 64, + "n_iter": 100, + "codebook_sizes": [256, 256, 256], + "should_init_codebooks": true, + "should_reinit_unused_clusters": true, + "initializer_range": 0.02 + }, + "optimizer": { + "type": "basic", + "optimizer": { + "type": "adam", + "lr": 5e-5 + }, + "clip_grad_threshold": 5.0 + }, + "loss": { + "type": "rqvae_loss", + "beta": 0.25, + "output_prefix": "loss" + }, + "callback": { + "type": "composite", + "callbacks": [ + { + "type": "metric", + "on_step": 1, + "loss_prefix": "loss" + } + ] + } +} diff --git a/configs/train/rqvae_train_grid_config.json b/configs/train/rqvae_train_grid_config.json new file mode 100644 index 00000000..e90a6cfa --- /dev/null +++ b/configs/train/rqvae_train_grid_config.json @@ -0,0 +1,86 @@ +{ + "experiment_name": "rqvae_beauty_grid", + "train_epochs_num": 50, + "dataset": { + "type": "rqvae", + "path_to_data_dir": "../data", + "name": "Beauty", + "samplers": { + "type": "identity" + } + }, + "dataset_params": { + }, + "dataloader": { + "train": { + "type": "torch", + "batch_size": 256, + "batch_processor": { + "type": "embed" + }, + "drop_last": false, + "shuffle": true + }, + "validation": { + "type": "torch", + "batch_size": 256, + "batch_processor": { + "type": "embed" + }, + "drop_last": false, + "shuffle": false + } + }, + "model": { + "type": "rqvae", + "input_dim": 512, + "codebook_sizes": [256, 256, 256, 256], + "should_init_codebooks": true, + "should_reinit_unused_clusters": true, + "initializer_range": 0.02 + }, + "model_params": { + "n_iter": [ + 100, + 500, + 2000 + ], + "hidden_dim": [ + 128, + 512, + 2048 + ] + }, + "optimizer": { + "type": "basic", + "optimizer": { + "type": "adam", + "lr": 1e-4 + }, + "clip_grad_threshold": 5.0, + "scheduler": { + "type": "step", + "step_size": 100, + "gamma": 0.96 + } + }, + "optimizer_params": { + }, + "loss": { + "type": "rqvae_loss", + "beta": 0.25, + "output_prefix": "loss" + }, + "loss_params": { + }, + "callback": { + "type": "composite", + "callbacks": [ + { + "type": "metric", + "on_step": 1, + "loss_prefix": "loss" + } + ] + } +} diff --git a/configs/train/sasrec_full_train_config.json b/configs/train/sasrec_full_train_config.json new file mode 100644 index 00000000..bbff4ee5 --- /dev/null +++ b/configs/train/sasrec_full_train_config.json @@ -0,0 +1,168 @@ +{ + "experiment_name": "sasrec_full_beauty", + "best_metric": "validation/ndcg@20", + "train_epochs_num": 100, + 
"dataset": { + "type": "sequence_full", + "path_to_data_dir": "../data", + "name": "Beauty", + "max_sequence_length": 50, + "samplers": { + "type": "last_item_prediction", + "negative_sampler_type": "random" + } + }, + "dataloader": { + "train": { + "type": "torch", + "batch_size": 256, + "batch_processor": { + "type": "basic" + }, + "drop_last": true, + "shuffle": true + }, + "validation": { + "type": "torch", + "batch_size": 256, + "batch_processor": { + "type": "basic" + }, + "drop_last": false, + "shuffle": false + } + }, + "model": { + "type": "sasrec_full", + "sequence_prefix": "item", + "positive_prefix": "labels", + "negative_prefix": "negative", + "candidate_prefix": "candidates", + "embedding_dim": 64, + "num_heads": 2, + "num_layers": 2, + "dim_feedforward": 256, + "dropout": 0.3, + "activation": "gelu", + "layer_norm_eps": 1e-9, + "initializer_range": 0.02 + }, + "optimizer": { + "type": "basic", + "optimizer": { + "type": "adam", + "lr": 0.001 + }, + "clip_grad_threshold": 5.0 + }, + "loss": { + "type": "composite", + "losses": [ + { + "type": "ce", + "predictions_prefix": "logits", + "labels_prefix": "labels", + "output_prefix": "downstream_loss" + } + ], + "output_prefix": "loss" + }, + "callback": { + "type": "composite", + "callbacks": [ + { + "type": "metric", + "on_step": 1, + "loss_prefix": "loss" + }, + { + "type": "validation", + "on_step": 1024, + "pred_prefix": "logits", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "ndcg", + "k": 20 + }, + "recall@5": { + "type": "recall", + "k": 5 + }, + "recall@10": { + "type": "recall", + "k": 10 + }, + "recall@20": { + "type": "recall", + "k": 20 + }, + "coverage@5": { + "type": "coverage", + "k": 5 + }, + "coverage@10": { + "type": "coverage", + "k": 10 + }, + "coverage@20": { + "type": "coverage", + "k": 20 + } + } + }, + { + "type": "eval", + "on_step": 2048, + "pred_prefix": "logits", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "ndcg", + "k": 20 + }, + "recall@5": { + "type": "recall", + "k": 5 + }, + "recall@10": { + "type": "recall", + "k": 10 + }, + "recall@20": { + "type": "recall", + "k": 20 + }, + "coverage@5": { + "type": "coverage", + "k": 5 + }, + "coverage@10": { + "type": "coverage", + "k": 10 + }, + "coverage@20": { + "type": "coverage", + "k": 20 + } + } + } + ] + } +} \ No newline at end of file diff --git a/configs/train/sasrec_in_batch_train_config.json b/configs/train/sasrec_in_batch_train_config.json index 8bde73f7..090a64d9 100644 --- a/configs/train/sasrec_in_batch_train_config.json +++ b/configs/train/sasrec_in_batch_train_config.json @@ -1,14 +1,14 @@ { - "experiment_name": "sasrec_in_batch_test", - "best_metric": "eval/ndcg@20", + "experiment_name": "sasrec_in_batch_beauty", + "best_metric": "validation/ndcg@20", + "train_epochs_num": 100, "dataset": { - "type": "sequence", + "type": "sequence_full", "path_to_data_dir": "../data", "name": "Beauty", "max_sequence_length": 50, "samplers": { - "num_negatives_val": 100, - "type": "next_item_prediction", + "type": "last_item_prediction", "negative_sampler_type": "random" } }, @@ -35,7 +35,7 @@ "model": { "type": "sasrec_in_batch", "sequence_prefix": "item", - "positive_prefix": "positive", + "positive_prefix": "labels", "negative_prefix": "negative", "candidate_prefix": "candidates", "embedding_dim": 64, @@ -44,7 +44,6 @@ 
"dim_feedforward": 256, "dropout": 0.3, "activation": "gelu", - "use_ce": true, "layer_norm_eps": 1e-9, "initializer_range": 0.02 }, @@ -79,7 +78,7 @@ }, { "type": "validation", - "on_step": 64, + "on_step": 1024, "pred_prefix": "logits", "labels_prefix": "labels", "metrics": { @@ -106,12 +105,24 @@ "recall@20": { "type": "recall", "k": 20 + }, + "coverage@5": { + "type": "coverage", + "k": 5 + }, + "coverage@10": { + "type": "coverage", + "k": 10 + }, + "coverage@20": { + "type": "coverage", + "k": 20 } } }, { "type": "eval", - "on_step": 256, + "on_step": 2048, "pred_prefix": "logits", "labels_prefix": "labels", "metrics": { @@ -138,9 +149,21 @@ "recall@20": { "type": "recall", "k": 20 + }, + "coverage@5": { + "type": "coverage", + "k": 5 + }, + "coverage@10": { + "type": "coverage", + "k": 10 + }, + "coverage@20": { + "type": "coverage", + "k": 20 } } } ] } -} +} \ No newline at end of file diff --git a/configs/train/sasrec_train_config.json b/configs/train/sasrec_real_train_config.json similarity index 91% rename from configs/train/sasrec_train_config.json rename to configs/train/sasrec_real_train_config.json index aa29a029..c43c3d42 100644 --- a/configs/train/sasrec_train_config.json +++ b/configs/train/sasrec_real_train_config.json @@ -1,13 +1,15 @@ { - "experiment_name": "sasrec_beauty", + "experiment_name": "sasrec_real_beauty", "best_metric": "validation/ndcg@20", + "train_epochs_num": 100, "dataset": { - "type": "sequence", + "type": "sequence_full", "path_to_data_dir": "../data", "name": "Beauty", "max_sequence_length": 50, "samplers": { - "type": "next_item_prediction", + "type": "last_item_prediction", + "num_negatives_train": 1, "negative_sampler_type": "random" } }, @@ -32,9 +34,9 @@ } }, "model": { - "type": "sasrec", + "type": "sasrec_real", "sequence_prefix": "item", - "positive_prefix": "positive", + "positive_prefix": "labels", "negative_prefix": "negative", "candidate_prefix": "candidates", "embedding_dim": 64, @@ -58,7 +60,7 @@ "type": "composite", "losses": [ { - "type": "sasrec", + "type": "sasrec_real", "positive_prefix": "positive_scores", "negative_prefix": "negative_scores", "output_prefix": "downstream_loss" @@ -76,7 +78,7 @@ }, { "type": "validation", - "on_step": 64, + "on_step": 1024, "pred_prefix": "logits", "labels_prefix": "labels", "metrics": { @@ -120,7 +122,7 @@ }, { "type": "eval", - "on_step": 256, + "on_step": 2048, "pred_prefix": "logits", "labels_prefix": "labels", "metrics": { @@ -164,4 +166,4 @@ } ] } -} +} \ No newline at end of file diff --git a/configs/train/sasrec_semantic_train_config.json b/configs/train/sasrec_semantic_train_config.json new file mode 100644 index 00000000..f37c6fc5 --- /dev/null +++ b/configs/train/sasrec_semantic_train_config.json @@ -0,0 +1,172 @@ +{ + "experiment_name": "sasrec_semantic_learnable_uid_beauty", + "best_metric": "validation/ndcg@20", + "train_epochs_num": 100, + "dataset": { + "type": "sequence_full", + "path_to_data_dir": "../data", + "name": "Beauty", + "max_sequence_length": 50, + "samplers": { + "type": "last_item_prediction", + "num_negatives_train": 1, + "negative_sampler_type": "random" + } + }, + "dataloader": { + "train": { + "type": "torch", + "batch_size": 256, + "batch_processor": { + "type": "basic" + }, + "drop_last": true, + "shuffle": true + }, + "validation": { + "type": "torch", + "batch_size": 256, + "batch_processor": { + "type": "basic" + }, + "drop_last": false, + "shuffle": false + } + }, + "model": { + "type": "sasrec_semantic", + "rqvae_train_config_path": 
"../configs/train/rqvae_train_config.json", + "rqvae_checkpoint_path": "../checkpoints/rqvae_beauty_final_state.pth", + "embs_extractor_path": "../data/Beauty/rqvae/data_full.pt", + "sequence_prefix": "item", + "positive_prefix": "labels", + "negative_prefix": "negative", + "candidate_prefix": "candidates", + "embedding_dim": 64, + "num_heads": 2, + "num_layers": 2, + "dim_feedforward": 256, + "dropout": 0.3, + "activation": "gelu", + "layer_norm_eps": 1e-9, + "initializer_range": 0.02 + }, + "optimizer": { + "type": "basic", + "optimizer": { + "type": "adam", + "lr": 0.001 + }, + "clip_grad_threshold": 5.0 + }, + "loss": { + "type": "composite", + "losses": [ + { + "type": "sasrec_real", + "positive_prefix": "positive_scores", + "negative_prefix": "negative_scores", + "output_prefix": "downstream_loss" + } + ], + "output_prefix": "loss" + }, + "callback": { + "type": "composite", + "callbacks": [ + { + "type": "metric", + "on_step": 1, + "loss_prefix": "loss" + }, + { + "type": "validation", + "on_step": 512, + "pred_prefix": "logits", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "ndcg", + "k": 20 + }, + "recall@5": { + "type": "recall", + "k": 5 + }, + "recall@10": { + "type": "recall", + "k": 10 + }, + "recall@20": { + "type": "recall", + "k": 20 + }, + "coverage@5": { + "type": "coverage", + "k": 5 + }, + "coverage@10": { + "type": "coverage", + "k": 10 + }, + "coverage@20": { + "type": "coverage", + "k": 20 + } + } + }, + { + "type": "eval", + "on_step": 1024, + "pred_prefix": "logits", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "ndcg", + "k": 20 + }, + "recall@5": { + "type": "recall", + "k": 5 + }, + "recall@10": { + "type": "recall", + "k": 10 + }, + "recall@20": { + "type": "recall", + "k": 20 + }, + "coverage@5": { + "type": "coverage", + "k": 5 + }, + "coverage@10": { + "type": "coverage", + "k": 10 + }, + "coverage@20": { + "type": "coverage", + "k": 20 + } + } + } + ] + } +} \ No newline at end of file diff --git a/configs/train/tiger_train_config.json b/configs/train/tiger_train_config.json new file mode 100644 index 00000000..4b14df40 --- /dev/null +++ b/configs/train/tiger_train_config.json @@ -0,0 +1,180 @@ +{ + "experiment_name": "tiger_simplified_no_residuals_unfreezed", + "best_metric": "validation/ndcg@20", + "train_epochs_num": 100, + "dataset": { + "type": "sequence_full", + "path_to_data_dir": "../data", + "name": "Beauty", + "max_sequence_length": 50, + "samplers": { + "type": "last_item_prediction", + "negative_sampler_type": "random" + } + }, + "dataloader": { + "train": { + "type": "torch", + "batch_size": 256, + "batch_processor": { + "type": "basic" + }, + "drop_last": true, + "shuffle": true + }, + "validation": { + "type": "torch", + "batch_size": 256, + "batch_processor": { + "type": "basic" + }, + "drop_last": false, + "shuffle": false + } + }, + "model": { + "type": "tiger", + "rqvae_train_config_path": "../configs/train/rqvae_train_config.json", + "rqvae_checkpoint_path": "../checkpoints/rqvae_beauty_final_state.pth", + "embs_extractor_path": "../data/Beauty/rqvae/data_full.pt", + "sequence_prefix": "item", + "predictions_prefix": "logits", + "positive_prefix": "labels", + "labels_prefix": "labels", + "embedding_dim": 64, + "num_heads": 2, + "num_encoder_layers": 2, + "num_decoder_layers": 2, + "dim_feedforward": 256, + 
"dropout": 0.3, + "activation": "gelu", + "layer_norm_eps": 1e-9, + "initializer_range": 0.02 + }, + "optimizer": { + "type": "basic", + "optimizer": { + "type": "adam", + "lr": 0.001 + }, + "clip_grad_threshold": 5.0 + }, + "loss": { + "type": "composite", + "losses": [ + { + "type": "ce", + "predictions_prefix": "logits", + "labels_prefix": "semantic.labels", + "weight": 1.0, + "output_prefix": "semantic_loss" + }, + { + "type": "ce", + "predictions_prefix": "dedup.logits", + "labels_prefix": "dedup.labels", + "weight": 1.0, + "output_prefix": "dedup_loss" + } + ], + "output_prefix": "loss" + }, + "callback": { + "type": "composite", + "callbacks": [ + { + "type": "metric", + "on_step": 1, + "loss_prefix": "loss" + }, + { + "type": "validation", + "on_step": 1024, + "pred_prefix": "logits", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "ndcg", + "k": 20 + }, + "recall@5": { + "type": "recall", + "k": 5 + }, + "recall@10": { + "type": "recall", + "k": 10 + }, + "recall@20": { + "type": "recall", + "k": 20 + }, + "coverage@5": { + "type": "coverage", + "k": 5 + }, + "coverage@10": { + "type": "coverage", + "k": 10 + }, + "coverage@20": { + "type": "coverage", + "k": 20 + } + } + }, + { + "type": "eval", + "on_step": 2048, + "pred_prefix": "logits", + "labels_prefix": "labels", + "metrics": { + "ndcg@5": { + "type": "ndcg", + "k": 5 + }, + "ndcg@10": { + "type": "ndcg", + "k": 10 + }, + "ndcg@20": { + "type": "ndcg", + "k": 20 + }, + "recall@5": { + "type": "recall", + "k": 5 + }, + "recall@10": { + "type": "recall", + "k": 10 + }, + "recall@20": { + "type": "recall", + "k": 20 + }, + "coverage@5": { + "type": "coverage", + "k": 5 + }, + "coverage@10": { + "type": "coverage", + "k": 10 + }, + "coverage@20": { + "type": "coverage", + "k": 20 + } + } + } + ] + } +} diff --git a/modeling/callbacks/base.py b/modeling/callbacks/base.py index d959dd4f..1ae26267 100644 --- a/modeling/callbacks/base.py +++ b/modeling/callbacks/base.py @@ -203,7 +203,6 @@ def __call__(self, inputs, step_num): self._model.eval() with torch.no_grad(): for batch in self._get_dataloader(): - for key, value in batch.items(): batch[key] = value.to(utils.DEVICE) diff --git a/modeling/dataloader/batch_processors.py b/modeling/dataloader/batch_processors.py index 436f98fe..d7567bd3 100644 --- a/modeling/dataloader/batch_processors.py +++ b/modeling/dataloader/batch_processors.py @@ -1,3 +1,6 @@ +import json +import re +from itertools import chain import torch from utils import MetaParent @@ -12,6 +15,14 @@ class IdentityBatchProcessor(BaseBatchProcessor, config_name='identity'): def __call__(self, batch): return torch.tensor(batch) + +class EmbedBatchProcessor(BaseBatchProcessor, config_name='embed'): + + def __call__(self, batch): + ids = torch.tensor([entry['item.id'] for entry in batch]) + embeds = torch.stack([entry['item.embed'] for entry in batch]) + + return {'ids': ids, 'embeddings': embeds} class BasicBatchProcessor(BaseBatchProcessor, config_name='basic'): @@ -35,3 +46,61 @@ def __call__(self, batch): processed_batch[part] = torch.tensor(values, dtype=torch.long) return processed_batch + + +class LetterBatchProcessor(BaseBatchProcessor, config_name='letter'): + def __init__(self, mapping, semantic_length): + self._mapping: dict[int, list[int]] = mapping + self._prefixes = ['item', 'labels', 'positive', 'negative'] + self._semantic_length = semantic_length + + @classmethod + def 
create_from_config(cls, config, **kwargs): + mapping_path = config["beauty_index_json"] + with open(mapping_path, "r") as f: + mapping = json.load(f) + + semantic_length = config["semantic_length"] + + parsed = {} + + for key, semantic_ids in mapping.items(): + numbers = [int(re.search(r'\d+', item).group()) for item in semantic_ids] + assert len(numbers) == semantic_length + parsed[int(key)] = numbers + + return cls(mapping=parsed, semantic_length=semantic_length) + + def __call__(self, batch): + processed_batch = {} + + for key in batch[0].keys(): + if key.endswith('.ids'): + prefix = key.split('.')[0] + assert '{}.length'.format(prefix) in batch[0] + + processed_batch[f'{prefix}.ids'] = [] + processed_batch[f'{prefix}.length'] = [] + + for sample in batch: + processed_batch[f'{prefix}.ids'].extend(sample[f'{prefix}.ids']) + processed_batch[f'{prefix}.length'].append(sample[f'{prefix}.length']) + + for prefix in self._prefixes: + if f"{prefix}.ids" in processed_batch: + ids = processed_batch[f"{prefix}.ids"] + lengths = processed_batch[f"{prefix}.length"] + + mapped_ids = [] + + for _id in ids: + mapped_ids.append(self._mapping[_id]) + + processed_batch[f"semantic_{prefix}.ids"] = list(chain.from_iterable(mapped_ids)) + processed_batch[f"semantic_{prefix}_tensor.ids"] = mapped_ids + processed_batch[f"semantic_{prefix}.length"] = [length * self._semantic_length for length in lengths] + + for part, values in processed_batch.items(): + processed_batch[part] = torch.tensor(values, dtype=torch.long) + + return processed_batch diff --git a/modeling/dataset/base.py b/modeling/dataset/base.py index 42a1516e..835c0578 100644 --- a/modeling/dataset/base.py +++ b/modeling/dataset/base.py @@ -1,4 +1,5 @@ from collections import defaultdict +import json from tqdm import tqdm @@ -195,7 +196,142 @@ def meta(self): 'num_items': self.num_items, 'max_sequence_length': self.max_sequence_length } + + +class SequenceFullDataset(SequenceDataset, config_name='sequence_full'): + @classmethod + def create_from_config(cls, config, **kwargs): + data_dir_path = os.path.join(config['path_to_data_dir'], config['name']) + + train_dataset, train_max_user_id, train_max_item_id, train_seq_len = cls._create_dataset( + dir_path=data_dir_path, + part='train', + max_sequence_length=config['max_sequence_length'], + use_cached=config.get('use_cached', False) + ) + validation_dataset, valid_max_user_id, valid_max_item_id, valid_seq_len = cls._create_dataset( + dir_path=data_dir_path, + part='valid', + max_sequence_length=config['max_sequence_length'], + use_cached=config.get('use_cached', False) + ) + test_dataset, test_max_user_id, test_max_item_id, test_seq_len = cls._create_dataset( + dir_path=data_dir_path, + part='test', + max_sequence_length=config['max_sequence_length'], + use_cached=config.get('use_cached', False) + ) + + max_user_id = max([train_max_user_id, valid_max_user_id, test_max_user_id]) + max_item_id = max([train_max_item_id, valid_max_item_id, test_max_item_id]) + max_seq_len = max([train_seq_len, valid_seq_len, test_seq_len]) + + logger.info('Train dataset size: {}'.format(len(train_dataset))) + logger.info("Validation dataset size: {}".format(len(validation_dataset))) + logger.info('Test dataset size: {}'.format(len(test_dataset))) + logger.info('Max user id: {}'.format(max_user_id)) + logger.info('Max item id: {}'.format(max_item_id)) + logger.info('Max sequence length: {}'.format(max_seq_len)) + + train_interactions = sum(list(map(lambda x: len(x), train_dataset))) # whole user history as a sample + 
valid_interactions = len(validation_dataset) # each new interaction as a sample + test_interactions = len(test_dataset) # each new interaction as a sample + logger.info('{} dataset sparsity: {}'.format( + config['name'], (train_interactions + valid_interactions + test_interactions) / max_user_id / max_item_id + )) + + train_sampler = TrainSampler.create_from_config( + config['samplers'], + dataset=train_dataset, + num_users=max_user_id, + num_items=max_item_id + ) + validation_sampler = EvalSampler.create_from_config( + config['samplers'], + dataset=validation_dataset, + num_users=max_user_id, + num_items=max_item_id + ) + test_sampler = EvalSampler.create_from_config( + config['samplers'], + dataset=test_dataset, + num_users=max_user_id, + num_items=max_item_id + ) + + return cls( + train_sampler=train_sampler, + validation_sampler=validation_sampler, + test_sampler=test_sampler, + num_users=max_user_id, + num_items=max_item_id, + max_sequence_length=max_seq_len + ) + + @classmethod + def flatten_item_sequence(cls, item_ids): + min_history_length = 3 # TODOPK make this configurable + histories = [] + for i in range(min_history_length, len(item_ids)): + histories.append(item_ids[:i]) + return histories + + @classmethod + def _create_dataset(cls, dir_path, part, max_sequence_length=None, use_cached=False): + max_user_id = 0 + max_item_id = 0 + max_sequence_len = 0 + + if use_cached and os.path.exists(os.path.join(dir_path, '{}.pkl'.format(part))): + logger.info(f'Take cached dataset from {os.path.join(dir_path, "{}.pkl".format(part))}') + with open(os.path.join(dir_path, '{}.pkl'.format(part)), 'rb') as dataset_file: + dataset, max_user_id, max_item_id, max_sequence_len = pickle.load(dataset_file) + else: + logger.info('Cache is forcefully ignored.'
if not use_cached else 'No cached dataset has been found.') + logger.info(f'Creating a dataset {os.path.join(dir_path, "{}.txt".format(part))}...') + + dataset_path = os.path.join(dir_path, '{}.txt'.format(part)) + with open(dataset_path, 'r') as f: + data = f.readlines() + + sequence_info = cls._create_sequences(data, max_sequence_length) + user_sequences, item_sequences, max_user_id, max_item_id, max_sequence_len = sequence_info + + # TODOPK check + dataset = [] + for user_id, item_ids in zip(user_sequences, item_sequences): + if part == "train": + flattened_item_ids = cls.flatten_item_sequence(item_ids) + for seq in flattened_item_ids: + dataset.append( + { + "user.ids": [user_id], + "user.length": 1, + "item.ids": seq, + "item.length": len(seq), + } + ) + else: + dataset.append( + { + "user.ids": [user_id], + "user.length": 1, + "item.ids": item_ids, + "item.length": len(item_ids), + } + ) + + logger.info('{} dataset size: {}'.format(part, len(dataset))) + logger.info('{} dataset max sequence length: {}'.format(part, max_sequence_len)) + + with open(os.path.join(dir_path, '{}.pkl'.format(part)), 'wb') as dataset_file: + pickle.dump( + (dataset, max_user_id, max_item_id, max_sequence_len), + dataset_file + ) + + return dataset, max_user_id, max_item_id, max_sequence_len class GraphDataset(BaseDataset, config_name='graph'): @@ -616,3 +752,264 @@ def meta(self): 'num_items': self.num_items, 'max_sequence_length': self.max_sequence_length } + + +class ScientificFullDataset(ScientificDataset, config_name="scientific_full"): + def __init__( + self, + train_sampler, + validation_sampler, + test_sampler, + num_users, + num_items, + max_sequence_length, + ): + self._train_sampler = train_sampler + self._validation_sampler = validation_sampler + self._test_sampler = test_sampler + self._num_users = num_users + self._num_items = num_items + self._max_sequence_length = max_sequence_length + + @classmethod + def create_from_config(cls, config, **kwargs): + data_dir_path = os.path.join(config["path_to_data_dir"], config["name"]) + max_sequence_length = config["max_sequence_length"] + max_user_id, max_item_id = 0, 0 + train_dataset, validation_dataset, test_dataset = [], [], [] + + dataset_path = os.path.join(data_dir_path, "{}.txt".format("all_data")) + with open(dataset_path, "r") as f: + data = f.readlines() + + for sample in data: + sample = sample.strip("\n").split(" ") + user_id = int(sample[0]) + item_ids = [int(item_id) for item_id in sample[1:]] + + max_user_id = max(max_user_id, user_id) + max_item_id = max(max_item_id, max(item_ids)) + + assert len(item_ids) >= 5 + + # item_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + # prefix_length: 5, 6, 7, 8, 9, 10 + for prefix_length in range(5, len(item_ids) + 1): + # prefix = [1, 2, 3, 4, 5] + # prefix = [1, 2, 3, 4, 5, 6] + # prefix = [1, 2, 3, 4, 5, 6, 7] + # prefix = [1, 2, 3, 4, 5, 6, 7, 8] + # prefix = [1, 2, 3, 4, 5, 6, 7, 8, 9] + # prefix = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + + + prefix = item_ids[ + :prefix_length + ] # TODOPK no sliding window, only incrementing sequence from last 50 items + + # prefix[:-2] = [1, 2, 3] + # prefix[:-2] = [1, 2, 3, 4] + # prefix[:-2] = [1, 2, 3, 4, 5] + # prefix[:-2] = [1, 2, 3, 4, 5, 6] + # prefix[:-2] = [1, 2, 3, 4, 5, 6, 7] + # prefix[:-2] = [1, 2, 3, 4, 5, 6, 7, 8] + + train_dataset.append( + { + "user.ids": [user_id], + "user.length": 1, + "item.ids": prefix[:-2][-max_sequence_length:], + "item.length": len(prefix[:-2][-max_sequence_length:]), + } + ) + assert len(prefix[:-2][-max_sequence_length:]) == len(
+ set(prefix[:-2][-max_sequence_length:]) + ) + + # item_ids[:-1] = [1, 2, 3, 4, 5, 6, 7, 8, 9] + validation_dataset.append( + { + "user.ids": [user_id], + "user.length": 1, + "item.ids": item_ids[:-1][-max_sequence_length:], + "item.length": len(item_ids[:-1][-max_sequence_length:]), + } + ) + assert len(item_ids[:-1][-max_sequence_length:]) == len( + set(item_ids[:-1][-max_sequence_length:]) + ) + + # item_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + test_dataset.append( + { + "user.ids": [user_id], + "user.length": 1, + "item.ids": item_ids[-max_sequence_length:], + "item.length": len(item_ids[-max_sequence_length:]), + } + ) + assert len(item_ids[-max_sequence_length:]) == len( + set(item_ids[-max_sequence_length:]) + ) + + logger.info("Train dataset size: {}".format(len(train_dataset))) + logger.info("Validation dataset size: {}".format(len(validation_dataset))) + logger.info("Test dataset size: {}".format(len(test_dataset))) + logger.info("Max user id: {}".format(max_user_id)) + logger.info("Max item id: {}".format(max_item_id)) + logger.info("Max sequence length: {}".format(max_sequence_length)) + logger.info( + "{} dataset sparsity: {}".format( + config["name"], + (len(train_dataset) + len(test_dataset)) / max_user_id / max_item_id, + ) + ) + + train_sampler = TrainSampler.create_from_config( + config["samplers"], + dataset=train_dataset, + num_users=max_user_id, + num_items=max_item_id, + ) + validation_sampler = EvalSampler.create_from_config( + config["samplers"], + dataset=validation_dataset, + num_users=max_user_id, + num_items=max_item_id, + ) + test_sampler = EvalSampler.create_from_config( + config["samplers"], + dataset=test_dataset, + num_users=max_user_id, + num_items=max_item_id, + ) + + return cls( + train_sampler=train_sampler, + validation_sampler=validation_sampler, + test_sampler=test_sampler, + num_users=max_user_id, + num_items=max_item_id, + max_sequence_length=max_sequence_length, + ) + + +class LetterFullDataset(ScientificFullDataset, config_name="letter_full"): + def __init__( + self, + train_sampler, + validation_sampler, + test_sampler, + num_users, + num_items, + max_sequence_length, + ): + self._train_sampler = train_sampler + self._validation_sampler = validation_sampler + self._test_sampler = test_sampler + self._num_users = num_users + self._num_items = num_items + self._max_sequence_length = max_sequence_length + + @classmethod + def create_from_config(cls, config, **kwargs): + user_interactions_path = os.path.join(config["beauty_inter_json"]) + with open(user_interactions_path, "r") as f: + user_interactions = json.load(f) + + dir_path = os.path.join(config["path_to_data_dir"], config["name"]) + + os.makedirs(dir_path, exist_ok=True) + dataset_path = os.path.join(dir_path, "all_data.txt") + + logger.info(f"Saving data to {dataset_path}") + + # Map from LETTER format to Our format + with open(dataset_path, "w") as f: + for user_id, item_ids in user_interactions.items(): + items_repr = map(str, item_ids) + f.write(f"{user_id} {' '.join(items_repr)}\n") + + dataset = ScientificFullDataset.create_from_config(config, **kwargs) + + return cls( + train_sampler=dataset._train_sampler, + validation_sampler=dataset._validation_sampler, + test_sampler=dataset._test_sampler, + num_users=dataset._num_users, + num_items=dataset._num_items, + max_sequence_length=dataset._max_sequence_length, + ) + + + +class RqVaeDataset(BaseDataset, config_name='rqvae'): + + def __init__( + self, + train_sampler, + validation_sampler, + test_sampler, + num_items + ): + 
self._train_sampler = train_sampler + self._validation_sampler = validation_sampler + self._test_sampler = test_sampler + self._num_items = num_items + self._max_sequence_length = None # samples are single item embeddings, so there is no sequence length + + @classmethod + def create_from_config(cls, config, **kwargs): + data_dir_path = os.path.join(config['path_to_data_dir'], config['name']) + train_dataset, validation_dataset, test_dataset = [], [], [] + + dataset_path = os.path.join(data_dir_path, '{}.pt'.format('data_full')) + df = torch.load(dataset_path, weights_only=False) + + for idx, sample in df.iterrows(): + train_dataset.append({ + 'item.id': idx, + 'item.embed': sample["embeddings"] + }) + + logger.info('Train dataset size: {}'.format(len(train_dataset))) + logger.info('Test dataset size: {}'.format(len(test_dataset))) + + train_sampler = TrainSampler.create_from_config( + config['samplers'], + dataset=train_dataset + ) + validation_sampler = EvalSampler.create_from_config( + config['samplers'], + dataset=validation_dataset + ) + test_sampler = EvalSampler.create_from_config( + config['samplers'], + dataset=test_dataset + ) + + return cls( + train_sampler=train_sampler, + validation_sampler=validation_sampler, + test_sampler=test_sampler, + num_items=len(df) + ) + + def get_samplers(self): + return self._train_sampler, self._validation_sampler, self._test_sampler + + @property + def num_items(self): + return self._num_items + + @property + def max_sequence_length(self): + return self._max_sequence_length + + @property + def meta(self): + return { + 'num_items': self.num_items, + 'train_sampler': self._train_sampler + } diff --git a/modeling/dataset/negative_samplers/random.py b/modeling/dataset/negative_samplers/random.py index b83042b0..81bf9038 100644 --- a/modeling/dataset/negative_samplers/random.py +++ b/modeling/dataset/negative_samplers/random.py @@ -1,33 +1,26 @@ from collections import defaultdict -from tqdm import tqdm - -from dataset.negative_samplers.base import BaseNegativeSampler - import numpy as np +from dataset.negative_samplers.base import BaseNegativeSampler +from tqdm import tqdm class RandomNegativeSampler(BaseNegativeSampler, config_name='random'): - @classmethod def create_from_config(cls, _, **kwargs): return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'] + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], ) def generate_negative_samples(self, sample, num_negatives): - user_id = sample['user.ids'][0] - all_items = list(range(1, self._num_items + 1)) - np.random.shuffle(all_items) + user_id = sample["user.ids"][0] + negatives = set() - negatives = [] - running_idx = 0 - while len(negatives) < num_negatives and running_idx < len(all_items): - negative_idx = all_items[running_idx] - if negative_idx not in self._seen_items[user_id]: - negatives.append(negative_idx) - running_idx += 1 + while len(negatives) < num_negatives: + candidate = np.random.randint(1, self._num_items + 1) + if candidate not in self._seen_items[user_id]: + negatives.add(candidate) - return negatives + return list(negatives) diff --git a/modeling/dataset/samplers/__init__.py b/modeling/dataset/samplers/__init__.py index 8c1da0fc..6ed31eed 100644 --- a/modeling/dataset/samplers/__init__.py +++ b/modeling/dataset/samplers/__init__.py @@ -7,3 +7,4 @@ from .mclsr import MCLSRTrainSampler, MCLSRPredictionEvalSampler from .pop import PopTrainSampler, PopEvalSampler from .s3rec import S3RecPretrainTrainSampler, S3RecPretrainEvalSampler +from .identity import IdentityTrainSampler, IdentityEvalSampler
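Note on the RandomNegativeSampler change above: the old version shuffled the full 1..num_items range on every call, which costs O(num_items) even when only one negative is needed; the new loop rejection-samples candidates, which is cheap whenever a user's seen-item set is small relative to the catalog. The trade-off is that the while-loop never terminates if num_negatives exceeds the number of unseen items. A minimal standalone sketch of the same idea with a bounded fallback (not part of the patch; the max_attempts guard and the fallback scan are assumptions):

    import numpy as np

    def sample_negatives(num_items, seen, num_negatives, max_attempts=100_000):
        """Rejection-sample item ids from 1..num_items (inclusive) that are not in `seen`."""
        negatives = set()
        attempts = 0
        while len(negatives) < num_negatives and attempts < max_attempts:
            # np.random.randint's upper bound is exclusive, hence num_items + 1
            candidate = int(np.random.randint(1, num_items + 1))
            if candidate not in seen:
                negatives.add(candidate)
            attempts += 1
        if len(negatives) < num_negatives:
            # Catalog nearly exhausted: fall back to an explicit scan over all ids
            for item_id in range(1, num_items + 1):
                if item_id not in seen and item_id not in negatives:
                    negatives.add(item_id)
                    if len(negatives) == num_negatives:
                        break
        return list(negatives)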
diff --git a/modeling/dataset/samplers/identity.py b/modeling/dataset/samplers/identity.py new file mode 100644 index 00000000..ffe01e23 --- /dev/null +++ b/modeling/dataset/samplers/identity.py @@ -0,0 +1,35 @@ +from dataset.samplers.base import TrainSampler, EvalSampler + +import copy + + +class IdentityTrainSampler(TrainSampler, config_name='identity'): + + def __init__(self, dataset): + super().__init__() + self._dataset = dataset + + @classmethod + def create_from_config(cls, config, **kwargs): + return cls( + dataset=kwargs['dataset'] + ) + + def __getitem__(self, index): + sample = copy.deepcopy(self._dataset[index]) + return sample + + +class IdentityEvalSampler(EvalSampler, config_name='identity'): + def __init__(self, dataset): + self._dataset = dataset + + @classmethod + def create_from_config(cls, config, **kwargs): + return cls( + dataset=kwargs['dataset'] + ) + + def __getitem__(self, index): + sample = copy.deepcopy(self._dataset[index]) + return sample \ No newline at end of file diff --git a/modeling/dataset/samplers/last_item_prediction.py b/modeling/dataset/samplers/last_item_prediction.py index c0a93212..474ef4c5 100644 --- a/modeling/dataset/samplers/last_item_prediction.py +++ b/modeling/dataset/samplers/last_item_prediction.py @@ -1,22 +1,29 @@ -from dataset.samplers.base import TrainSampler, EvalSampler - import copy +from dataset.samplers.base import EvalSampler, TrainSampler +from dataset.negative_samplers.base import BaseNegativeSampler -class LastItemPredictionTrainSampler(TrainSampler, config_name='last_item_prediction'): - - def __init__(self, dataset, num_users, num_items): +class LastItemPredictionTrainSampler(TrainSampler, config_name="last_item_prediction"): + def __init__(self, dataset, num_users, num_items, negative_sampler, num_negatives): super().__init__() self._dataset = dataset self._num_users = num_users self._num_items = num_items + self._negative_sampler = negative_sampler + self._num_negatives = num_negatives @classmethod def create_from_config(cls, config, **kwargs): + negative_sampler = BaseNegativeSampler.create_from_config( + {"type": config["negative_sampler_type"]}, **kwargs + ) + return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], + dataset=kwargs["dataset"], + num_users=kwargs["num_users"], + num_items=kwargs["num_items"], + negative_sampler=negative_sampler, + num_negatives=config.get("num_negatives_train", 0), ) def __getitem__(self, index): @@ -25,16 +32,30 @@ def __getitem__(self, index): item_sequence = sample['item.ids'][:-1] last_item = sample['item.ids'][-1] - return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - - 'item.ids': item_sequence, - 'item.length': len(item_sequence), + if self._num_negatives == 0: + return { + "user.ids": sample["user.ids"], + "user.length": sample["user.length"], + "item.ids": item_sequence, + "item.length": len(item_sequence), + "labels.ids": [last_item], + "labels.length": 1, + } + else: + negative_sequence = self._negative_sampler.generate_negative_samples( + sample, self._num_negatives + ) - 'labels.ids': [last_item], - 'labels.length': 1, - } + return { + "user.ids": sample["user.ids"], + "user.length": sample["user.length"], + "item.ids": item_sequence, + "item.length": len(item_sequence), + "labels.ids": [last_item], + "labels.length": 1, + "negative.ids": negative_sequence, + "negative.length": len(negative_sequence), + } class LastItemPredictionEvalSampler(EvalSampler, config_name='last_item_prediction'): 
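For reference, the reworked LastItemPredictionTrainSampler above yields dictionaries of the following shape. Given a stored sample with user 42 and interaction history [3, 7, 9, 12], and num_negatives_train set to 1, a training item looks roughly like this (the concrete negative id 105 is invented for illustration; in practice it is whatever the configured negative sampler returns):

    {
        'user.ids': [42], 'user.length': 1,
        'item.ids': [3, 7, 9], 'item.length': 3,      # history without the last item
        'labels.ids': [12], 'labels.length': 1,       # the held-out last item is the target
        'negative.ids': [105], 'negative.length': 1,  # produced by the negative sampler
    }

With num_negatives_train left at its default of 0, the negative.* keys are omitted and the sampler behaves as before.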
diff --git a/modeling/loss/base.py b/modeling/loss/base.py index dc03d6cb..8ec91326 100644 --- a/modeling/loss/base.py +++ b/modeling/loss/base.py @@ -53,6 +53,43 @@ def forward(self, inputs): return total_loss +class SampleLogSoftmaxLoss(TorchLoss, config_name='sample_logsoftmax'): + def __init__(self, predictions_prefix, labels): + super().__init__() + self._predictions_prefix = predictions_prefix + self._labels = labels + + @classmethod + def create_from_config(cls, config, **kwargs): + return cls( + predictions_prefix=config.get('predictions_prefix'), + labels=config.get('labels') + ) + + def forward(self, inputs): # use log soft max + logits = inputs[self._predictions_prefix] + candidates = inputs[self._labels] + + assert len(logits.shape) in [2, 3] + + batch_size = logits.shape[0] + seq_len = logits.shape[1] + + if len(logits.shape) == 3: + loss = -torch.gather( + torch.log_softmax(logits, dim=-1).reshape(batch_size * seq_len, logits.shape[-1]), + dim=-1, + index=candidates.reshape(batch_size * seq_len, 1) + ).mean() + else: + loss = -torch.gather( + torch.log_softmax(logits, dim=-1), + dim=-1, + index=candidates.reshape(batch_size, 1) + ).mean() + + return loss + class BatchLogSoftmaxLoss(TorchLoss, config_name='batch_logsoftmax'): @@ -104,6 +141,45 @@ def forward(self, inputs): inputs[self._output_prefix] = loss.cpu().item() return loss + +class RqVaeLoss(TorchLoss, config_name='rqvae_loss'): + + def __init__(self, beta, output_prefix=None): + super().__init__() + self.beta = beta + self._output_prefix = output_prefix + + self._loss = nn.MSELoss() + + @classmethod + def create_from_config(cls, config, **kwargs): + # 0.25 is default Beta in paper + return cls( + beta = config.get('beta', 0.25), + output_prefix = config['output_prefix'], + ) + + def forward(self, inputs): + embeddings = inputs["embeddings"] + embeddings_restored = inputs["embeddings_restored"] + remainders = inputs["remainders"] + codebooks_vectors = inputs["codebooks_vectors"] + + rqvae_loss = 0 + + for remainder, codebook_vectors in zip(remainders, codebooks_vectors): + rqvae_loss += self.beta * self._loss( + remainder, codebook_vectors.detach() + ) + rqvae_loss += self._loss(codebook_vectors, remainder.detach()) + + recon_loss = self._loss(embeddings_restored, embeddings) + loss = (recon_loss + rqvae_loss).mean(dim=0) + + if self._output_prefix is not None: + inputs[self._output_prefix] = loss.cpu().item() + + return loss class BinaryCrossEntropyLoss(TorchLoss, config_name='bce'): @@ -241,6 +317,29 @@ def forward(self, inputs): return loss +class SASRecRealLoss(TorchLoss, config_name="sasrec_real"): + def __init__(self, positive_prefix, negative_prefix, output_prefix=None): + super().__init__() + self._positive_prefix = positive_prefix + self._negative_prefix = negative_prefix + self._output_prefix = output_prefix + + def forward(self, inputs): + positive_scores = inputs[self._positive_prefix] # (x) + negative_scores = inputs[self._negative_prefix] # (x) + assert positive_scores.shape[0] == negative_scores.shape[0] + + positive_loss = torch.log(nn.functional.sigmoid(positive_scores) + 1e-9) # (x) + negative_loss = torch.log(1.0 - nn.functional.sigmoid(negative_scores) + 1e-9) # (x) + loss = positive_loss + negative_loss # (x) + loss = -loss.mean() # (1) + + if self._output_prefix is not None: + inputs[self._output_prefix] = loss.cpu().item() + + return loss + + class SASRecLoss(TorchLoss, config_name='sasrec'): def __init__( @@ -257,10 +356,28 @@ def __init__( def forward(self, inputs): positive_scores = 
inputs[self._positive_prefix] # (x, embedding_dim) negative_scores = inputs[self._negative_prefix] # (x, embedding_dim) + sample_ids = inputs["sample_ids"] + + num_items = negative_scores.shape[1] - 2 + + possible_indices = torch.arange(1, num_items + 1, device=negative_scores.device) # 1, 2, ... num_items + mask = torch.ones_like(possible_indices, dtype=torch.bool) # True, True, ... True + mask[sample_ids - 1] = False # True, False, ... False, True, ... True + valid_indices = possible_indices[mask] # 1, 2, ... num_items, except sample_ids + + rand_idx = torch.randint(0, len(valid_indices), size=(negative_scores.shape[0], 1), device=negative_scores.device) + index = valid_indices[rand_idx] + + negative_scores = torch.gather( + input=negative_scores, + dim=1, + index=index, + ) + assert positive_scores.shape[0] == negative_scores.shape[0] positive_loss = torch.log(nn.functional.sigmoid(positive_scores)).sum(dim=-1) # (x) - negative_loss = torch.log(1.0 - nn.functional.sigmoid(negative_scores)).sum(dim=-1) # (x) + negative_loss = torch.log(1.0 - nn.functional.sigmoid(negative_scores) + 1e-9).sum(dim=-1) # (x), added 1e-9 for Tiger baseline loss = positive_loss + negative_loss # (x) loss = -loss.sum() # (1) @@ -287,8 +404,13 @@ def __init__( def forward(self, inputs): queries_embeddings = inputs[self._queries_prefix] # (batch_size, embedding_dim) - positive_embeddings = inputs[self._positive_prefix] # (batch_size, embedding_dim) - negative_embeddings = inputs[self._negative_prefix] # (num_negatives, embedding_dim) or (batch_size, num_negatives, embedding_dim) + # TODOPK check + positive_ids, positive_embeddings = inputs[ + self._positive_prefix + ] # (batch_size, embedding_dim) + negative_ids, negative_embeddings = inputs[ + self._negative_prefix + ] # (num_negatives, embedding_dim) or (batch_size, num_negatives, embedding_dim) # b -- batch_size, d -- embedding_dim positive_scores = torch.einsum( @@ -304,6 +426,15 @@ def forward(self, inputs): queries_embeddings, negative_embeddings ) # (batch_size, num_negatives) + + all_scores = torch.cat( + [positive_scores, negative_scores], dim=1 + ) # (batch_size, 1 + num_negatives) + logits = torch.log_softmax( + all_scores, dim=1 + ) # (batch_size, 1 + num_negatives) + loss = (-logits)[:, 0] # (batch_size) + loss = loss.mean() # (1) else: assert negative_embeddings.dim() == 3 # (batch_size, num_negatives, embedding_dim) # b -- batch_size, n -- num_negatives, d -- embedding_dim @@ -312,11 +443,8 @@ def forward(self, inputs): queries_embeddings, negative_embeddings ) # (batch_size, num_negatives) - all_scores = torch.cat([positive_scores, negative_scores], dim=1) # (batch_size, 1 + num_negatives) - logits = torch.log_softmax(all_scores, dim=1) # (batch_size, 1 + num_negatives) - loss = (-logits)[:, 0] # (batch_size) - loss = loss.mean() # (1) + raise NotImplementedError("per-sample (3-dim) negatives are not supported yet; the intended loss here is unclear (TODOPK clarify with Vladimir)") if self._output_prefix is not None: inputs[self._output_prefix] = loss.cpu().item() diff --git a/modeling/models/__init__.py b/modeling/models/__init__.py index b4341e75..71fc9643 100644 --- a/modeling/models/__init__.py +++ b/modeling/models/__init__.py @@ -12,6 +12,11 @@ from .pop import PopModel from .pure_mf import PureMF from .random import RandomModel -from .sasrec import SasRecModel, SasRecInBatchModel -from .sasrec_ce import SasRecCeModel +from .rqvae import RqVaeModel from .s3rec import S3RecModel +from .sasrec_ce import SasRecCeModel +from .sasrec_full import SasRecFullModel +from .sasrec_in_batch import SasRecInBatchModel +from .sasrec_real import SasRecRealModel
+from .sasrec_semantic import SasRecSemanticModel +from .tiger import TigerModel diff --git a/modeling/models/base.py b/modeling/models/base.py index e09a0eb4..a1384384 100644 --- a/modeling/models/base.py +++ b/modeling/models/base.py @@ -1,9 +1,6 @@ -from utils import MetaParent - -from utils import DEVICE, create_masked_tensor, get_activation_function - import torch import torch.nn as nn +from utils import DEVICE, MetaParent, create_masked_tensor, get_activation_function class BaseModel(metaclass=MetaParent): @@ -27,6 +24,13 @@ def _init_weights(self, initializer_range): ) elif 'bias' in key: nn.init.zeros_(value.data) + elif 'codebook' in key: + nn.init.trunc_normal_( + value.data, + std=initializer_range, + a=-2 * initializer_range, + b=2 * initializer_range + ) else: raise ValueError(f'Unknown transformer weight: {key}') @@ -90,9 +94,18 @@ def __init__( batch_first=True ) self._encoder = nn.TransformerEncoder(transformer_encoder_layer, num_layers) + + def get_item_embeddings(self, events): + return self._item_embeddings(events) + + def _apply_sequential_encoder( + self, events, lengths, add_cls_token=False, user_embeddings=None + ): + embeddings = self.get_item_embeddings( + events + ) # (all_batch_events, embedding_dim) - def _apply_sequential_encoder(self, events, lengths, add_cls_token=False): - embeddings = self._item_embeddings(events) # (all_batch_events, embedding_dim) + assert embeddings.shape[0] == sum(lengths) embeddings, mask = create_masked_tensor( data=embeddings, @@ -102,17 +115,7 @@ def _apply_sequential_encoder(self, events, lengths, add_cls_token=False): batch_size = mask.shape[0] seq_len = mask.shape[1] - positions = torch.arange( - start=seq_len - 1, end=-1, step=-1, device=mask.device - )[None].tile([batch_size, 1]).long() # (batch_size, seq_len) - positions_mask = positions < lengths[:, None] # (batch_size, max_seq_len) - - positions = positions[positions_mask] # (all_batch_events) - position_embeddings = self._position_embeddings(positions) # (all_batch_events, embedding_dim) - position_embeddings, _ = create_masked_tensor( - data=position_embeddings, - lengths=lengths - ) # (batch_size, seq_len, embedding_dim) + position_embeddings = self._encoder_pos_embeddings(lengths, mask) assert torch.allclose(position_embeddings[~mask], embeddings[~mask]) embeddings = embeddings + position_embeddings # (batch_size, seq_len, embedding_dim) @@ -128,6 +131,14 @@ def _apply_sequential_encoder(self, events, lengths, add_cls_token=False): embeddings = torch.cat((cls_token_expanded, embeddings), dim=1) mask = torch.cat((torch.ones((batch_size, 1), dtype=torch.bool, device=DEVICE), mask), dim=1) + if user_embeddings is not None: + embeddings = torch.cat((user_embeddings.unsqueeze(1), embeddings), dim=1) + mask = torch.cat( + (torch.ones((batch_size, 1), dtype=torch.bool, device=DEVICE), mask), + dim=1, + ) + seq_len += 1 # TODOPK ask if this is correct + if self._is_causal: causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool().to(DEVICE) # (seq_len, seq_len) embeddings = self._encoder( @@ -143,6 +154,23 @@ def _apply_sequential_encoder(self, events, lengths, add_cls_token=False): return embeddings, mask + def _encoder_pos_embeddings(self, lengths, mask): + batch_size = mask.shape[0] + seq_len = mask.shape[1] + + positions = torch.arange( + start=seq_len - 1, end=-1, step=-1, device=mask.device + )[None].tile([batch_size, 1]).long() # (batch_size, seq_len) + positions_mask = positions < lengths[:, None] # (batch_size, max_seq_len) + + positions = positions[positions_mask] 
# (all_batch_events) + position_embeddings = self._position_embeddings(positions) # (all_batch_events, embedding_dim) + position_embeddings, _ = create_masked_tensor( + data=position_embeddings, + lengths=lengths + ) # (batch_size, seq_len, embedding_dim) + return position_embeddings + @staticmethod def _add_cls_token(items, lengths, cls_token_id=0): num_items = items.shape[0] diff --git a/modeling/models/rqvae.py b/modeling/models/rqvae.py new file mode 100644 index 00000000..fc70b890 --- /dev/null +++ b/modeling/models/rqvae.py @@ -0,0 +1,163 @@ +import functools +from utils import DEVICE +from models.base import TorchModel + +import torch +import faiss + +class RqVaeModel(TorchModel, config_name='rqvae'): + + def __init__( + self, + train_sampler, + input_dim: int, + hidden_dim: int, + n_iter: int, + codebook_sizes: list[int], + should_init_codebooks, + should_reinit_unused_clusters, + initializer_range + ): + super().__init__() + + self.n_iter = n_iter + + # Kmeans initialization + self.should_init_codebooks = should_init_codebooks + + # Trick with re-initing empty clusters + self.should_reinit_unused_clusters = should_reinit_unused_clusters + + # Enc and dec are mirrored copies of each other + self.encoder = self.make_encoding_tower(input_dim, hidden_dim) + self.decoder = self.make_encoding_tower(hidden_dim, input_dim) + + # Default initialization of codebook + self.codebooks = torch.nn.ParameterList() + + self.codebook_sizes = codebook_sizes + + for codebook_size in codebook_sizes: + cb = torch.FloatTensor(codebook_size, hidden_dim) + self.codebooks.append(cb) + + self._init_weights(initializer_range) + + if self.should_init_codebooks: + if train_sampler is None: + raise AttributeError("Train sampler is None") + + embeddings = torch.stack([entry['item.embed'] for entry in train_sampler._dataset]) + self.init_codebooks(embeddings) + print('Codebooks initialized with Faiss Kmeans') + self.should_init_codebooks = False + + @classmethod + def create_from_config(cls, config, **kwargs): + return cls( + train_sampler=kwargs.get('train_sampler'), + input_dim=config['embedding_dim'], + hidden_dim=config['hidden_dim'], + n_iter=config['n_iter'], + codebook_sizes=config['codebook_sizes'], + should_init_codebooks=config.get('should_init_codebooks', False), + should_reinit_unused_clusters=config.get('should_reinit_unused_clusters', False), + initializer_range=config.get('initializer_range', 0.02) + ) + + def make_encoding_tower(self, d1: int, d2: int): + return torch.nn.Linear(d1, d2, bias=False) + + @staticmethod + def get_codebook_indices(remainder, codebook): + dist = torch.cdist(remainder, codebook) + return dist.argmin(dim=-1) + + def init_codebooks(self, embeddings): + with torch.no_grad(): + remainder = self.encoder(embeddings) + for codebook in self.codebooks: + embeddings_np = remainder.cpu().numpy() + n_clusters = codebook.shape[0] + + kmeans = faiss.Kmeans( + d=embeddings_np.shape[1], + k=n_clusters, + niter=self.n_iter, + ) + kmeans.train(embeddings_np) + + codebook.data = torch.from_numpy(kmeans.centroids).to(codebook.device) + + codebook_indices = self.get_codebook_indices(remainder, codebook) + codebook_vectors = codebook[codebook_indices] + remainder = remainder - codebook_vectors + + @staticmethod + def reinit_unused_clusters(remainder, codebook, codebook_indices): + with torch.no_grad(): + is_used = torch.full((codebook.shape[0],), False, device=codebook.device) + unique_indices = codebook_indices.unique() + is_used[unique_indices] = True + rand_input = torch.randint(0, 
remainder.shape[0], ((~is_used).sum(),)) + codebook[~is_used] = remainder[rand_input] + + def train_pass(self, embeddings): + latent_vector = self.encoder(embeddings) + + latent_restored = 0 + + num_unique_clusters = [] + remainder = latent_vector + + remainders = [] + codebooks_vectors = [] + + for codebook in self.codebooks: + remainders.append(remainder) + + codebook_indices = self.get_codebook_indices(remainder, codebook) + codebook_vectors = codebook[codebook_indices] + + if self.should_reinit_unused_clusters: + self.reinit_unused_clusters(remainder, codebook, codebook_indices) + + num_unique_clusters.append(codebook_indices.unique().shape[0]) + + codebooks_vectors.append(codebook_vectors) + + latent_restored = latent_restored + codebook_vectors + remainder = remainder - codebook_vectors + + # Here we cast recon loss to latent vector + latent_restored = latent_vector + (latent_restored - latent_vector).detach() + embeddings_restored = self.decoder(latent_restored) + + return { + "embeddings": embeddings, + "embeddings_restored": embeddings_restored, + "remainders": remainders, + "codebooks_vectors": codebooks_vectors + } + + def eval_pass(self, embeddings): + ind_lists = [] + remainder = self.encoder(embeddings) + for codebook in self.codebooks: + codebook_indices = self.get_codebook_indices(remainder, codebook) + codebook_vectors = codebook[codebook_indices] + ind_lists.append(codebook_indices.cpu().numpy()) + remainder = remainder - codebook_vectors + return torch.tensor(list(zip(*ind_lists))).to(DEVICE), remainder + + def forward(self, inputs): + embeddings = inputs["embeddings"] + + if self.training: # training mode + return self.train_pass(embeddings) + else: # eval mode + return self.eval_pass(embeddings) + + @functools.cache + def get_single_embedding(self, codebook_idx: int, codebook_id: int): + return self.codebooks[codebook_idx][codebook_id] diff --git a/modeling/models/sasrec.py b/modeling/models/sasrec.py deleted file mode 100644 index 1ef5c9e7..00000000 --- a/modeling/models/sasrec.py +++ /dev/null @@ -1,219 +0,0 @@ -from models import SequentialTorchModel -from utils import create_masked_tensor - -import torch - - -class SasRecModel(SequentialTorchModel, config_name='sasrec'): - - def __init__( - self, - sequence_prefix, - positive_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02 - ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - is_causal=True - ) - self._sequence_prefix = sequence_prefix - self._positive_prefix = positive_prefix - - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get('num_heads', int(config['embedding_dim'] // 64)), - num_layers=config['num_layers'], - dim_feedforward=config.get('dim_feedforward', 4 * config['embedding_dim']), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02) - ) - - def forward(self, inputs): - all_sample_events = 
inputs['{}.ids'.format(self._sequence_prefix)] # (all_batch_events) - all_sample_lengths = inputs['{}.length'.format(self._sequence_prefix)] # (batch_size) - - embeddings, mask = self._apply_sequential_encoder( - all_sample_events, all_sample_lengths - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - if self.training: # training mode - all_positive_sample_events = inputs['{}.ids'.format(self._positive_prefix)] # (all_batch_events) - - all_sample_embeddings = embeddings[mask] # (all_batch_events, embedding_dim) - - all_embeddings = self._item_embeddings.weight # (num_items + 2, embedding_dim) - - # a -- all_batch_events, n -- num_items + 2, d -- embedding_dim - all_scores = torch.einsum( - 'ad,nd->an', - all_sample_embeddings, - all_embeddings - ) # (all_batch_events, num_items + 2) - - positive_scores = torch.gather( - input=all_scores, - dim=1, - index=all_positive_sample_events[..., None] - ) # (all_batch_items, 1) - - sample_ids, _ = create_masked_tensor( - data=all_sample_events, - lengths=all_sample_lengths - ) # (batch_size, seq_len) - - sample_ids = torch.repeat_interleave(sample_ids, all_sample_lengths, dim=0) # (all_batch_events, seq_len) - - negative_scores = torch.scatter( - input=all_scores, - dim=1, - index=sample_ids, - src=torch.ones_like(sample_ids) * (-torch.inf) - ) # (all_batch_events, num_items + 2) - negative_scores[:, 0] = -torch.inf # Padding idx - negative_scores[:, self._num_items + 1:] = -torch.inf # Mask idx - - return { - 'positive_scores': positive_scores, - 'negative_scores': negative_scores - } - else: # eval mode - last_embeddings = self._get_last_embedding(embeddings, mask) # (batch_size, embedding_dim) - # b - batch_size, n - num_candidates, d - embedding_dim - candidate_scores = torch.einsum( - 'bd,nd->bn', - last_embeddings, - self._item_embeddings.weight - ) # (batch_size, num_items + 2) - candidate_scores[:, 0] = -torch.inf # Padding id - candidate_scores[:, self._num_items + 1:] = -torch.inf # Mask id - - _, indices = torch.topk( - candidate_scores, - k=20, dim=-1, largest=True - ) # (batch_size, 20) - - return indices - - -class SasRecInBatchModel(SasRecModel, config_name='sasrec_in_batch'): - - def __init__( - self, - sequence_prefix, - positive_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02 - ): - super().__init__( - sequence_prefix=sequence_prefix, - positive_prefix=positive_prefix, - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - initializer_range=initializer_range - ) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get('num_heads', int(config['embedding_dim'] // 64)), - num_layers=config['num_layers'], - dim_feedforward=config.get('dim_feedforward', 4 * config['embedding_dim']), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02) - ) - - def forward(self, inputs): - all_sample_events = inputs['{}.ids'.format(self._sequence_prefix)] # (all_batch_events) - all_sample_lengths 
= inputs['{}.length'.format(self._sequence_prefix)] # (batch_size) - - embeddings, mask = self._apply_sequential_encoder( - all_sample_events, all_sample_lengths - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - if self.training: # training mode - # queries - in_batch_queries_embeddings = embeddings[mask] # (all_batch_events, embedding_dim) - - # positives - in_batch_positive_events = inputs['{}.ids'.format(self._positive_prefix)] # (all_batch_events) - in_batch_positive_embeddings = self._item_embeddings( - in_batch_positive_events - ) # (all_batch_events, embedding_dim) - - # negatives - batch_size = all_sample_lengths.shape[0] - random_ids = torch.randperm(in_batch_positive_events.shape[0]) - in_batch_negative_ids = in_batch_positive_events[random_ids][:batch_size] - - in_batch_negative_embeddings = self._item_embeddings( - in_batch_negative_ids - ) # (batch_size, embedding_dim) - - return { - 'query_embeddings': in_batch_queries_embeddings, - 'positive_embeddings': in_batch_positive_embeddings, - 'negative_embeddings': in_batch_negative_embeddings - } - else: # eval mode - last_embeddings = self._get_last_embedding(embeddings, mask) # (batch_size, embedding_dim) - - # b - batch_size, n - num_candidates, d - embedding_dim - candidate_scores = torch.einsum( - 'bd,nd->bn', - last_embeddings, - self._item_embeddings.weight - ) # (batch_size, num_items + 2) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1:] = -torch.inf - - _, indices = torch.topk( - candidate_scores, - k=20, dim=-1, largest=True - ) # (batch_size, 20) - - return indices diff --git a/modeling/models/sasrec_full.py b/modeling/models/sasrec_full.py new file mode 100644 index 00000000..5d484319 --- /dev/null +++ b/modeling/models/sasrec_full.py @@ -0,0 +1,92 @@ +import torch +from models.base import SequentialTorchModel + + +class SasRecFullModel(SequentialTorchModel, config_name="sasrec_full"): + def __init__( + self, + sequence_prefix, + positive_prefix, + num_items, + max_sequence_length, + embedding_dim, + num_heads, + num_layers, + dim_feedforward, + dropout=0.0, + activation="relu", + layer_norm_eps=1e-9, + initializer_range=0.02, + ): + super().__init__( + num_items=num_items, + max_sequence_length=max_sequence_length, + embedding_dim=embedding_dim, + num_heads=num_heads, + num_layers=num_layers, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + layer_norm_eps=layer_norm_eps, + is_causal=True, + ) + self._sequence_prefix = sequence_prefix + self._positive_prefix = positive_prefix + + self._init_weights(initializer_range) + + @classmethod + def create_from_config(cls, config, **kwargs): + return cls( + sequence_prefix=config["sequence_prefix"], + positive_prefix=config["positive_prefix"], + num_items=kwargs["num_items"], + max_sequence_length=kwargs["max_sequence_length"], + embedding_dim=config["embedding_dim"], + num_heads=config.get("num_heads", int(config["embedding_dim"] // 64)), + num_layers=config["num_layers"], + dim_feedforward=config.get("dim_feedforward", 4 * config["embedding_dim"]), + dropout=config.get("dropout", 0.0), + initializer_range=config.get("initializer_range", 0.02), + ) + + def forward(self, inputs): + all_sample_events = inputs[ + "{}.ids".format(self._sequence_prefix) + ] # (all_batch_events) + all_sample_lengths = inputs[ + "{}.length".format(self._sequence_prefix) + ] # (batch_size) + + embeddings, mask = self._apply_sequential_encoder( + all_sample_events, all_sample_lengths + ) # (batch_size, seq_len, 
embedding_dim), (batch_size, seq_len) + + last_embeddings = self._get_last_embedding( + embeddings, mask + ) # (batch_size, embedding_dim) + + if self.training: # training mode + all_scores = torch.einsum( + "bd,nd->bn", last_embeddings, self._item_embeddings.weight + ) # (batch_size, num_items + 2) + + # positives + in_batch_positive_events = inputs[ + "{}.ids".format(self._positive_prefix) + ] # (batch_size) + + return {"labels.ids": in_batch_positive_events, "logits": all_scores} + else: # eval mode + # b - batch_size, n - num_candidates, d - embedding_dim + candidate_scores = torch.einsum( + "bd,nd->bn", last_embeddings, self._item_embeddings.weight + ) # (batch_size, num_items + 2) + candidate_scores[:, 0] = -torch.inf + candidate_scores[:, self._num_items + 1 :] = -torch.inf + + _, indices = torch.topk( + candidate_scores, k=20, dim=-1, largest=True + ) # (batch_size, 20) + + return indices diff --git a/modeling/models/sasrec_in_batch.py b/modeling/models/sasrec_in_batch.py new file mode 100644 index 00000000..e9cfb930 --- /dev/null +++ b/modeling/models/sasrec_in_batch.py @@ -0,0 +1,119 @@ +import torch +from models.base import SequentialTorchModel + + +class SasRecInBatchModel(SequentialTorchModel, config_name="sasrec_in_batch"): + def __init__( + self, + sequence_prefix, + positive_prefix, + num_items, + max_sequence_length, + embedding_dim, + num_heads, + num_layers, + dim_feedforward, + num_in_batch_negatives=-1, + dropout=0.0, + activation="relu", + layer_norm_eps=1e-9, + initializer_range=0.02, + ): + super().__init__( + num_items=num_items, + max_sequence_length=max_sequence_length, + embedding_dim=embedding_dim, + num_heads=num_heads, + num_layers=num_layers, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + layer_norm_eps=layer_norm_eps, + is_causal=True, + ) + self._sequence_prefix = sequence_prefix + self._positive_prefix = positive_prefix + self._num_in_batch_negatives = num_in_batch_negatives + self._init_weights(initializer_range) + + @classmethod + def create_from_config(cls, config, **kwargs): + return cls( + sequence_prefix=config["sequence_prefix"], + positive_prefix=config["positive_prefix"], + num_items=kwargs["num_items"], + max_sequence_length=kwargs["max_sequence_length"], + embedding_dim=config["embedding_dim"], + num_heads=config.get("num_heads", int(config["embedding_dim"] // 64)), + num_layers=config["num_layers"], + num_in_batch_negatives=config.get("num_in_batch_negatives", -1), + dim_feedforward=config.get("dim_feedforward", 4 * config["embedding_dim"]), + dropout=config.get("dropout", 0.0), + initializer_range=config.get("initializer_range", 0.02), + ) + + def forward(self, inputs): + all_sample_events = inputs[ + "{}.ids".format(self._sequence_prefix) + ] # (all_batch_events) + all_sample_lengths = inputs[ + "{}.length".format(self._sequence_prefix) + ] # (batch_size) + + embeddings, mask = self._apply_sequential_encoder( + all_sample_events, all_sample_lengths + ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) + + last_embeddings = self._get_last_embedding( + embeddings, mask + ) # (batch_size, embedding_dim) + + if self.training: # training mode + # positives + in_batch_positive_events = inputs[ + "{}.ids".format(self._positive_prefix) + ] # (all_batch_events) + in_batch_positive_embeddings = self._item_embeddings( + in_batch_positive_events + ) # (all_batch_events, embedding_dim) + + # negatives + num_in_batch_negatives = self._num_in_batch_negatives + batch_size =
all_sample_lengths.shape[0] + random_ids = torch.randperm(in_batch_positive_events.shape[0]) + if num_in_batch_negatives == -1: + num_in_batch_negatives = batch_size + in_batch_negative_ids = in_batch_positive_events[ + random_ids + ][ + :num_in_batch_negatives + ] + + in_batch_negative_embeddings = self._item_embeddings( + in_batch_negative_ids + ) # (num_in_batch_negatives, embedding_dim) + + return { + "query_embeddings": last_embeddings, + "positive_embeddings": ( + in_batch_positive_events, + in_batch_positive_embeddings, + ), + "negative_embeddings": ( + in_batch_negative_ids, + in_batch_negative_embeddings, + ), + } + else: # eval mode + # b - batch_size, n - num_candidates, d - embedding_dim + candidate_scores = torch.einsum( + "bd,nd->bn", last_embeddings, self._item_embeddings.weight + ) # (batch_size, num_items + 2) + candidate_scores[:, 0] = -torch.inf + candidate_scores[:, self._num_items + 1 :] = -torch.inf + + _, indices = torch.topk( + candidate_scores, k=20, dim=-1, largest=True + ) # (batch_size, 20) + + return indices diff --git a/modeling/models/sasrec_real.py b/modeling/models/sasrec_real.py new file mode 100644 index 00000000..d87743f4 --- /dev/null +++ b/modeling/models/sasrec_real.py @@ -0,0 +1,111 @@ +import torch +from models.base import SequentialTorchModel + + +class SasRecRealModel(SequentialTorchModel, config_name="sasrec_real"): + def __init__( + self, + sequence_prefix, + positive_prefix, + negative_prefix, + num_items, + max_sequence_length, + embedding_dim, + num_heads, + num_layers, + dim_feedforward, + dropout=0.0, + activation="relu", + layer_norm_eps=1e-9, + initializer_range=0.02, + ): + super().__init__( + num_items=num_items, + max_sequence_length=max_sequence_length, + embedding_dim=embedding_dim, + num_heads=num_heads, + num_layers=num_layers, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + layer_norm_eps=layer_norm_eps, + is_causal=True, + ) + self._sequence_prefix = sequence_prefix + self._positive_prefix = positive_prefix + self._negative_prefix = negative_prefix + + self._init_weights(initializer_range) + + @classmethod + def create_from_config(cls, config, **kwargs): + return cls( + sequence_prefix=config["sequence_prefix"], + positive_prefix=config["positive_prefix"], + negative_prefix=config["negative_prefix"], + num_items=kwargs["num_items"], + max_sequence_length=kwargs["max_sequence_length"], + embedding_dim=config["embedding_dim"], + num_heads=config.get("num_heads", int(config["embedding_dim"] // 64)), + num_layers=config["num_layers"], + dim_feedforward=config.get("dim_feedforward", 4 * config["embedding_dim"]), + dropout=config.get("dropout", 0.0), + initializer_range=config.get("initializer_range", 0.02), + ) + + def forward(self, inputs): + all_sample_events = inputs[ + "{}.ids".format(self._sequence_prefix) + ] # (all_batch_events) + all_sample_lengths = inputs[ + "{}.length".format(self._sequence_prefix) + ] # (batch_size) + + embeddings, mask = self._apply_sequential_encoder( + all_sample_events, all_sample_lengths + ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) + + last_embeddings = self._get_last_embedding( + embeddings, mask + ) # (batch_size, embedding_dim) + + if self.training: # training mode + # positives + in_batch_positive_events = inputs[ + "{}.ids".format(self._positive_prefix) + ] # (all_batch_events) + in_batch_positive_embeddings = self._item_embeddings( + in_batch_positive_events + ) # (all_batch_events, embedding_dim) + positive_scores = torch.einsum( + 
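# row-wise dot product: each sequence's last hidden state is scored against its own positive item +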
"bd,bd->b", last_embeddings, in_batch_positive_embeddings + ) # (all_batch_events) + + # negatives + in_batch_negative_events = inputs[ + "{}.ids".format(self._negative_prefix) + ] # (all_batch_events) + in_batch_negative_embeddings = self._item_embeddings( + in_batch_negative_events + ) # (all_batch_events, embedding_dim) + negative_scores = torch.einsum( + "bd,bd->b", last_embeddings, in_batch_negative_embeddings + ) # (all_batch_events) + + return { + "positive_scores": positive_scores, + "negative_scores": negative_scores, + } + else: # eval mode + # b - batch_size, n - num_candidates, d - embedding_dim + candidate_scores = torch.einsum( + "bd,nd->bn", last_embeddings, self._item_embeddings.weight + ) # (batch_size, num_items + 2) + candidate_scores[:, 0] = -torch.inf + candidate_scores[:, self._num_items + 1 :] = -torch.inf + + _, indices = torch.topk( + candidate_scores, k=20, dim=-1, largest=True + ) # (batch_size, 20) + + return indices diff --git a/modeling/models/sasrec_semantic.py b/modeling/models/sasrec_semantic.py new file mode 100644 index 00000000..0230f41e --- /dev/null +++ b/modeling/models/sasrec_semantic.py @@ -0,0 +1,239 @@ +import torch +from .tiger import TigerModel +from models import SequentialTorchModel +from torch import nn +from utils import DEVICE, create_masked_tensor +from torch import nn + + +class SasRecSemanticModel(SequentialTorchModel, config_name="sasrec_semantic"): + def __init__( + self, + rqvae_model, + item_id_to_semantic_id, + item_id_to_residual, + sequence_prefix, + positive_prefix, + negative_prefix, + num_items, + num_users, + max_sequence_length, + embedding_dim, + num_heads, + num_layers, + dim_feedforward, + dropout=0.0, + activation="relu", + layer_norm_eps=1e-9, + initializer_range=0.02, + ): + super().__init__( + num_items=num_items, + max_sequence_length=max_sequence_length, + embedding_dim=embedding_dim, + num_heads=num_heads, + num_layers=num_layers, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + layer_norm_eps=layer_norm_eps, + is_causal=True, + ) + self._sequence_prefix = sequence_prefix + self._positive_prefix = positive_prefix + self._negative_prefix = negative_prefix + + self._num_users = num_users + + self._codebook_sizes = rqvae_model.codebook_sizes + + self._codebook_embeddings = nn.Embedding( + num_embeddings=len(self._codebook_sizes) + 2, embedding_dim=embedding_dim + ) # + 2 for bos token & residual + + self._user_embeddings = nn.Embedding( + num_embeddings=self._num_users + 1, embedding_dim=embedding_dim + ) + + self._init_weights(initializer_range) + + self._item_id_to_semantic_id = ( + item_id_to_semantic_id # len(num_items), len(self._codebook_sizes) + ) + self._item_id_to_residual = item_id_to_residual # len(num_items), embedding_dim + + self._codebooks = nn.Parameter( + torch.stack([codebook for codebook in rqvae_model.codebooks]), + requires_grad=True, + ) # len(self._codebook_sizes), codebook_size, embedding_dim + + @classmethod + def create_from_config(cls, config, **kwargs): + rqvae_model, semantic_ids, residuals, _ = TigerModel.init_rqvae(config) + + return cls( + rqvae_model=rqvae_model, + item_id_to_semantic_id=semantic_ids, + item_id_to_residual=residuals, + sequence_prefix=config["sequence_prefix"], + positive_prefix=config["positive_prefix"], + negative_prefix=config["negative_prefix"], + num_items=kwargs["num_items"], + num_users=kwargs["num_users"], + max_sequence_length=kwargs["max_sequence_length"], + embedding_dim=config["embedding_dim"], + 
num_heads=config.get("num_heads", int(config["embedding_dim"] // 64)), + num_layers=config["num_layers"], + dim_feedforward=config.get("dim_feedforward", 4 * config["embedding_dim"]), + dropout=config.get("dropout", 0.0), + initializer_range=config.get("initializer_range", 0.02), + ) + + def forward(self, inputs): + all_sample_events = inputs[ + "{}.ids".format(self._sequence_prefix) + ] # (all_batch_events) + all_sample_lengths = inputs[ + "{}.length".format(self._sequence_prefix) + ] # (batch_size) + + user_embeddings = self._user_embeddings(inputs["user.ids"]) + + embeddings, mask = self._apply_sequential_encoder( + all_sample_events, + all_sample_lengths * (len(self._codebook_sizes) + 1), + user_embeddings=user_embeddings, + ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) + + last_embeddings = self._get_last_embedding( + embeddings, mask + ) # (batch_size, embedding_dim) + + if self.training: # training mode + # positives + in_batch_positive_events = inputs[ + "{}.ids".format(self._positive_prefix) + ] # (all_batch_events) + in_batch_positive_embeddings = self.get_embeddings( + in_batch_positive_events - 1 + ).sum(dim=1) # (all_batch_events, embedding_dim) + positive_scores = torch.einsum( + "bd,bd->b", last_embeddings, in_batch_positive_embeddings + ) # (all_batch_events) + + # TODOPK normalize in all models embeddings for stability + + # negatives + in_batch_negative_events = inputs[ + "{}.ids".format(self._negative_prefix) + ] # (all_batch_events) + in_batch_negative_embeddings = self.get_embeddings( + in_batch_negative_events - 1 + ).sum(dim=1) # (all_batch_events, embedding_dim) + negative_scores = torch.einsum( + "bd,bd->b", last_embeddings, in_batch_negative_embeddings + ) # (all_batch_events) + + return { + "positive_scores": positive_scores, + "negative_scores": negative_scores, + } + else: # eval mode + item_embeddings = self.get_embeddings(torch.arange(self._num_items)).sum( + dim=1 + ) # num_items, embedding_dim + # b - batch_size, n - num_candidates, d - embedding_dim + candidate_scores = torch.einsum( + "bd,nd->bn", + last_embeddings, + item_embeddings, + ) # (batch_size, num_items) + + _, indices = torch.topk( + candidate_scores, k=20, dim=-1, largest=True + ) # (batch_size, 20) + + return indices + 1 # tensors are 0 indexed + + def get_item_embeddings(self, events): + item_embeddings = self.get_embeddings( + events - 1 + ) # len(events), len(self._codebook_sizes) + 1, embedding_dim + return item_embeddings.reshape(-1, self._embedding_dim) + + def get_embeddings(self, events): # events = 0 ... 
num_items - 1 + semantic_ids = self._item_id_to_semantic_id[ + events + ] # len(events), len(self._codebook_sizes) + residuals = self._item_id_to_residual[events] # len(events), embedding_dim + + semantic_embeddings = torch.stack( + [ + codebook[semantic_ids[:, i]] + for i, codebook in enumerate(self._codebooks) + ], + dim=1, + ) # len(events), len(self._codebook_sizes), embedding_dim + + residual = residuals.unsqueeze(1) + + # get true item embeddings + item_embeddings = torch.cat( + [semantic_embeddings, residual], dim=1 + ) # len(events), len(self._codebook_sizes) + 1, embedding_dim + + return item_embeddings + + def _encoder_pos_embeddings(self, lengths, mask): + def position_lambda(x): + return x // ( + len(self._codebook_sizes) + 1 + ) # 5 5 5 5 4 4 4 4 ..., +1 for residual + + position_embeddings = self._get_position_embeddings( + lengths, mask, position_lambda, self._position_embeddings + ) + + def codebook_lambda(x): + x = len(self._codebook_sizes) - x % (len(self._codebook_sizes) + 1) + x[x == len(self._codebook_sizes)] = len(self._codebook_sizes) + 1 + # 0 1 2 4 0 1 2 4 ... # len(self._codebook_sizes) + 1 = 4 for residual + return x + + codebook_embeddings = self._get_position_embeddings( + lengths, mask, codebook_lambda, self._codebook_embeddings + ) + + return position_embeddings + codebook_embeddings + + def _get_position_embeddings(self, lengths, mask, position_lambda, embedding_layer): + batch_size = mask.shape[0] + seq_len = mask.shape[1] + + positions = ( + torch.arange(start=seq_len - 1, end=-1, step=-1, device=DEVICE)[None] + .tile([batch_size, 1]) + .long() + ) # (batch_size, seq_len) + positions_mask = positions < lengths[:, None] # (batch_size, max_seq_len) + + positions = positions[positions_mask] # (all_batch_events) + # 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0 7 6 5 4 3 2 1 0 ... 
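+ # worked example, assuming len(self._codebook_sizes) == 3 (4 tokens per item): + # raw reversed positions: 7 6 5 4 3 2 1 0 + # position_lambda -> item index: 1 1 1 1 0 0 0 0 + # codebook_lambda -> token type: 0 1 2 4 0 1 2 4 (4 marks the residual slot)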
+ + positions = position_lambda(positions) # (all_batch_events) + + # print(f"{positions.tolist()[:20]=}") + + assert (positions >= 0).all() and ( + positions < embedding_layer.num_embeddings + ).all() + + position_embeddings = embedding_layer( + positions + ) # (all_batch_events, embedding_dim) + + position_embeddings, _ = create_masked_tensor( + data=position_embeddings, lengths=lengths + ) # (batch_size, seq_len, embedding_dim) + + return position_embeddings diff --git a/modeling/models/tiger.py b/modeling/models/tiger.py new file mode 100644 index 00000000..81d5825d --- /dev/null +++ b/modeling/models/tiger.py @@ -0,0 +1,489 @@ +import json + +import torch +from models.base import SequentialTorchModel +from rqvae_utils import CollisionSolver, SimplifiedTree +from torch import nn +from utils import DEVICE, create_masked_tensor, get_activation_function + +from .rqvae import RqVaeModel + + +class TigerModel(SequentialTorchModel, config_name="tiger"): + def __init__( + self, + rqvae_model, + item_id_to_semantic_id, + item_id_to_residual, + solver, + sequence_prefix, + pred_prefix, + positive_prefix, + labels_prefix, + num_items, + max_sequence_length, + embedding_dim, + num_heads, + num_encoder_layers, + num_decoder_layers, + dim_feedforward, + dropout=0.0, + activation="relu", + layer_norm_eps=1e-9, + initializer_range=0.02, + ): + super().__init__( + num_items=num_items, + max_sequence_length=max_sequence_length, + embedding_dim=embedding_dim, + num_heads=num_heads, + num_layers=num_encoder_layers, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + layer_norm_eps=layer_norm_eps, + is_causal=True, + ) + + self._sequence_prefix = sequence_prefix + self._pred_prefix = pred_prefix + self._positive_prefix = positive_prefix + self._labels_prefix = labels_prefix + + transformer_decoder_layer = nn.TransformerDecoderLayer( + d_model=embedding_dim, + nhead=num_heads, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=get_activation_function(activation), + layer_norm_eps=layer_norm_eps, + batch_first=True, + ) + + self._decoder = nn.TransformerDecoder( + transformer_decoder_layer, num_decoder_layers + ) + + self._decoder_layernorm = nn.LayerNorm(embedding_dim, eps=layer_norm_eps) + self._decoder_dropout = nn.Dropout(dropout) + + self._solver: CollisionSolver = solver + + self._codebook_sizes = rqvae_model.codebook_sizes + self._bos_weight = nn.Parameter( + torch.nn.init.trunc_normal_( + torch.zeros(embedding_dim), + std=initializer_range, + a=-2 * initializer_range, + b=2 * initializer_range, + ), + requires_grad=True, # TODOPK added for bos + ) + + self._codebook_embeddings = nn.Embedding( + num_embeddings=len(self._codebook_sizes) + 2, embedding_dim=embedding_dim + ) # + 2 for bos token & residual + + self._init_weights(initializer_range) + + self._codebook_item_embeddings_stacked = nn.Parameter( + torch.stack([codebook for codebook in rqvae_model.codebooks]), + requires_grad=True, + ) # TODOPK (ask is it ok to have separate codebooks and _item_id_to_semantic_embedding) + + self._item_id_to_semantic_id = item_id_to_semantic_id + self._item_id_to_residual = item_id_to_residual + + self._item_id_to_semantic_embedding = nn.Parameter( + self.get_init_item_embeddings(item_id_to_semantic_id, item_id_to_residual), + requires_grad=True, + ) + + self._trie = SimplifiedTree(self._codebook_item_embeddings_stacked) + + self._trie.build_tree_structure( + item_id_to_semantic_id.to(DEVICE), + item_id_to_residual.to(DEVICE), + torch.arange(1, 
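# item ids are 1-based in this model, so eval shifts predictions back with indices+1 +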
len(item_id_to_semantic_id) + 1).to(DEVICE), + sum_with_residuals=False, + ) + + @classmethod + def init_rqvae(cls, config): + rqvae_config = json.load(open(config["rqvae_train_config_path"])) + rqvae_config["model"]["should_init_codebooks"] = False + + rqvae_model = RqVaeModel.create_from_config(rqvae_config["model"]).to(DEVICE) + rqvae_model.load_state_dict( + torch.load(config["rqvae_checkpoint_path"], weights_only=True) + ) + rqvae_model.eval() + for param in rqvae_model.parameters(): + param.requires_grad = False + + codebook_sizes = rqvae_model.codebook_sizes + assert all([book_size == codebook_sizes[0] for book_size in codebook_sizes]) + + embs_extractor = torch.load(config["embs_extractor_path"], weights_only=False) + + embs_extractor = embs_extractor.sort_index() + + item_ids = embs_extractor.index.tolist() + assert item_ids == list(range(1, len(item_ids) + 1)) + + text_embeddings = torch.stack(embs_extractor["embeddings"].tolist()).to(DEVICE) + + semantic_ids, residuals = rqvae_model({"embeddings": text_embeddings}) + + return rqvae_model, semantic_ids, residuals, item_ids + + @classmethod + def create_from_config(cls, config, **kwargs): + rqvae_model, semantic_ids, residuals, item_ids = cls.init_rqvae(config) + + solver = CollisionSolver( + emb_dim=residuals.shape[1], + sem_id_len=len(rqvae_model.codebook_sizes), + codebook_size=rqvae_model.codebook_sizes[0], + ) + solver.create_query_candidates_dict( + torch.tensor(item_ids), semantic_ids, residuals + ) + + return cls( + rqvae_model=rqvae_model, + item_id_to_semantic_id=semantic_ids, + item_id_to_residual=residuals, + solver=solver, + sequence_prefix=config["sequence_prefix"], + pred_prefix=config["predictions_prefix"], + positive_prefix=config["positive_prefix"], + labels_prefix=config["labels_prefix"], + num_items=rqvae_model.codebook_sizes[0], # unused + max_sequence_length=kwargs["max_sequence_length"], + embedding_dim=config["embedding_dim"], + num_heads=config.get("num_heads", int(config["embedding_dim"] // 64)), + num_encoder_layers=config["num_encoder_layers"], + num_decoder_layers=config["num_decoder_layers"], + dim_feedforward=config.get("dim_feedforward", 4 * config["embedding_dim"]), + dropout=config.get("dropout", 0.0), + initializer_range=config.get("initializer_range", 0.02), + ) + + # semantic ids come with dedup token + def forward(self, inputs): + all_sample_events = inputs[ + "{}.ids".format(self._sequence_prefix) + ] # (all_batch_events) + all_sample_lengths = inputs[ + "{}.length".format(self._sequence_prefix) + ] # (batch_size) + + encoder_embeddings, encoder_mask = self._apply_sequential_encoder( + all_sample_events, all_sample_lengths * (len(self._codebook_sizes) + 1) + ) # (batch_size, enc_seq_len, embedding_dim), (batch_size, enc_seq_len) + + if self.training: + label_events = inputs["{}.ids".format(self._positive_prefix)] + label_lengths = inputs["{}.length".format(self._positive_prefix)] + + tgt_embeddings = self.get_item_embeddings( + label_events + ) # (all_batch_events, embedding_dim) + + decoder_outputs = self._apply_decoder( + tgt_embeddings, + label_lengths * (len(self._codebook_sizes) + 1), + encoder_embeddings, + encoder_mask, + ) # (batch_size, label_len, embedding_dim) + + decoder_prefix_scores = torch.einsum( + "bsd,scd->bsc", + decoder_outputs[:, :-1, :], + self._codebook_item_embeddings_stacked, + ) + + decoder_output_residual = decoder_outputs[:, -1, :] + + semantic_ids = self._item_id_to_semantic_id[ + label_events - 1 + ] # len(events), len(codebook_sizes) + true_residuals = 
self._item_id_to_residual[label_events - 1] + + true_info = self._solver.get_true_dedup_tokens(semantic_ids, true_residuals) + pred_info = self._solver.get_pred_scores( + semantic_ids, decoder_output_residual + ) + + return { + "logits": decoder_prefix_scores.reshape( + -1, decoder_prefix_scores.shape[2] + ), + "semantic.labels.ids": semantic_ids.reshape(-1), + "dedup.logits": pred_info["pred_scores"], + "dedup.labels.ids": true_info["true_dedup_tokens"], + } + # else: + # semantic_ids, tgt_embeddings = self._apply_decoder_autoregressive( + # encoder_embeddings, encoder_mask + # ) # (batch_size, len(self._codebook_sizes) (bos, residual)), (batch_size, len(self._codebook_sizes) + 2 (bos, residual), embedding_dim) + # TODOPK + # # 1 4 6 -> lookup -> sum = emb (last embedding) # bs, embedding_dim + # # take all embedings (from stacked) # all_items, embedding_dim + # # take from sasrec eval (indices + 1) + # # guarantee that all items are in correct order + + # residuals = tgt_embeddings[:, -1, :] + # semantic_ids = semantic_ids.to(torch.int64) + + # item_ids = self._trie.query(semantic_ids, items_to_query=20) + + # return item_ids + # TODOPK + # uid -> hash (murmurhash32) -> modulo (2000) -> get_embedding -> prepend + # first iteration -> for each user get embedding + + else: # eval mode + semantic_ids, tgt_embeddings = self._apply_decoder_autoregressive( + encoder_embeddings, encoder_mask + ) # (batch_size, len(self._codebook_sizes)), (batch_size, len(self._codebook_sizes) + 2, embedding_dim) + + embs = [] + for semantic_id in semantic_ids: + cur_emb = [] + for idx, codebook_id in enumerate(semantic_id): + cur_emb.append( + self._codebook_item_embeddings_stacked[idx][codebook_id.item()] + ) + embs.append(torch.stack(cur_emb)) + + last_embeddings = torch.stack(embs).sum(dim=1) # batch_size, embedding_dim + + candidate_scores = torch.einsum( + "bd,nd->bn", + last_embeddings, + self._item_id_to_semantic_embedding.sum(dim=1), + ) # (batch_size, num_items) + + _, indices = torch.topk( + candidate_scores, k=20, dim=-1, largest=True + ) # (batch_size, 20) + + return indices + 1 # tensors are 0 indexed + + def _apply_decoder( + self, tgt_embeddings, label_lengths, encoder_embeddings, encoder_mask + ): + tgt_embeddings, tgt_mask = create_masked_tensor( + data=tgt_embeddings, lengths=label_lengths + ) # (batch_size, dec_seq_len, embedding_dim), (batch_size, dec_seq_len) + + batch_size = tgt_embeddings.shape[0] + bos_embeddings = self._bos_weight.unsqueeze(0).expand( + batch_size, 1, -1 + ) # (batch_size, 1, embedding_dim) + + tgt_embeddings = torch.cat( + [bos_embeddings, tgt_embeddings[:, :-1, :]], dim=1 + ) # remove residual by using :-1 + + label_len = tgt_mask.shape[1] + + assert label_len == len(self._codebook_sizes) + 1 + + position_embeddings = self._decoder_pos_embeddings(label_lengths, tgt_mask) + assert torch.allclose(position_embeddings[~tgt_mask], tgt_embeddings[~tgt_mask]) + + tgt_embeddings = tgt_embeddings + position_embeddings + + # TODOPK remove layernorm & dropout (for inference) + # tgt_embeddings = self._decoder_layernorm( + # tgt_embeddings + # ) # (batch_size, dec_seq_len, embedding_dim) + # tgt_embeddings = self._decoder_dropout( + # tgt_embeddings + # ) # (batch_size, dec_seq_len, embedding_dim) + + tgt_embeddings[~tgt_mask] = 0 + + causal_mask = ( + torch.tril(torch.ones(label_len, label_len)).bool().to(DEVICE) + ) # (dec_seq_len, dec_seq_len) + + decoder_outputs = self._decoder( + tgt=tgt_embeddings, + memory=encoder_embeddings, + tgt_mask=~causal_mask, + 
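# nn.TransformerDecoder expects True to mean 'masked', hence the inverted causal and padding masks +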
memory_key_padding_mask=~encoder_mask, + ) # (batch_size, dec_seq_len, embedding_dim) + + return decoder_outputs + + def _decoder_pos_embeddings(self, lengths, mask): + def codebook_lambda(x): + non_bos = x < len(self._codebook_sizes) + x[non_bos] = (len(self._codebook_sizes) - 1) - x[non_bos] + return x # 3, 0, 1, 2, 3, 0, 1, 2 ... len(self._codebook_sizes) = 3 for bos + + codebook_embeddings = self._get_position_embeddings( + lengths, mask, codebook_lambda, self._codebook_embeddings + ) + + return codebook_embeddings + + def _apply_decoder_autoregressive(self, encoder_embeddings, encoder_mask): + batch_size = encoder_embeddings.shape[0] + embedding_dim = encoder_embeddings.shape[2] + + tgt_embeddings = ( + self._bos_weight.unsqueeze(0) + .unsqueeze(0) + .expand(batch_size, 1, embedding_dim) + ) + + semantic_ids = torch.tensor([], device=DEVICE, dtype=torch.int64) + + for step in range(len(self._codebook_sizes) + 1): # semantic_id_seq + residual + index = len(self._codebook_sizes) if step == 0 else step - 1 + + last_position_embedding = self._codebook_embeddings( + torch.full((batch_size,), index, device=DEVICE) + ) + + assert last_position_embedding.shape == tgt_embeddings[:, -1, :].shape + assert tgt_embeddings.shape == torch.Size([batch_size, step + 1, embedding_dim]) + + curr_step_embeddings = tgt_embeddings.clone() + curr_step_embeddings[:, -1, :] = ( + tgt_embeddings[:, -1, :] + last_position_embedding + ) + assert torch.allclose(tgt_embeddings[:, :-1, :], curr_step_embeddings[:, :-1, :]) + tgt_embeddings = curr_step_embeddings + + # curr_embeddings[:, -1, :] = self._decoder_layernorm(curr_embeddings[:, -1, :]) + # curr_embeddings[:, -1, :] = self._decoder_dropout(curr_embeddings[:, -1, :]) + + causal_mask = ( + torch.tril(torch.ones(step + 1, step + 1)).bool().to(DEVICE) + ) # (dec_seq_len, dec_seq_len) + + decoder_output = self._decoder( + tgt=tgt_embeddings, + memory=encoder_embeddings, + tgt_mask=~causal_mask, + memory_key_padding_mask=~encoder_mask, + ) + + # TODOPK add assert for all except last layer (check if only last layer changes) + # TODOPK check decoder output for several outputs + # TODOPK ASK it is not true? 
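+ # note: each step re-runs the decoder over the full generated prefix (no KV cache), so decoding cost grows quadratically with the number of generated tokens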
+ # assert that prelast items don't change + # assert decoder changes only last index in dim = 1 + + next_token_embedding = decoder_output[ + :, -1, : + ] # batch_size x embedding_dim + + if step < len(self._codebook_sizes): + codebook = self._codebook_item_embeddings_stacked[ + step + ] # codebook_size x embedding_dim + closest_semantic_ids = torch.argmax( + torch.einsum("bd,cd->bc", next_token_embedding, codebook), dim=1 + ) # batch_size + semantic_ids = torch.cat( + [semantic_ids, closest_semantic_ids.unsqueeze(1)], dim=1 + ) # batch_size x (step + 1) + next_token_embedding = codebook[ + closest_semantic_ids + ] # batch_size x embedding_dim + + tgt_embeddings = torch.cat( + [tgt_embeddings, next_token_embedding.unsqueeze(1)], dim=1 + ) + + return semantic_ids, tgt_embeddings + + def get_item_embeddings(self, events): + embs = self._item_id_to_semantic_embedding[ + events - 1 + ] # len(events), len(self._codebook_sizes) + 1, embedding_dim + return embs.reshape(-1, self._embedding_dim) + + def get_init_item_embeddings(self, item_id_to_semantic_id, item_id_to_residual): + result = [] + for semantic_id in item_id_to_semantic_id: + item_repr = [] + for codebook_idx, codebook_id in enumerate(semantic_id): + item_repr.append( + self._codebook_item_embeddings_stacked[codebook_idx][codebook_id] + ) + result.append(torch.stack(item_repr)) + + semantic_embeddings = torch.stack( + result + ) # len(events), len(codebook_sizes), embedding_dim + + residual = item_id_to_residual.unsqueeze(1) + + # get true item embeddings + item_embeddings = torch.cat( + [semantic_embeddings, residual], dim=1 + ) # len(events), len(self._codebook_sizes) + 1, embedding_dim + + return item_embeddings + + def _encoder_pos_embeddings(self, lengths, mask): + def position_lambda(x): + return x // ( + len(self._codebook_sizes) + 1 + ) # 5 5 5 5 4 4 4 4 ..., +1 for residual + + position_embeddings = self._get_position_embeddings( + lengths, mask, position_lambda, self._position_embeddings + ) + + def codebook_lambda(x): + x = len(self._codebook_sizes) - x % (len(self._codebook_sizes) + 1) + x[x == len(self._codebook_sizes)] = len(self._codebook_sizes) + 1 + # 0 1 2 4 0 1 2 4 ... # len(self._codebook_sizes) + 1 = 4 for residual + return x + + codebook_embeddings = self._get_position_embeddings( + lengths, mask, codebook_lambda, self._codebook_embeddings + ) + + return position_embeddings + codebook_embeddings + + def _get_position_embeddings(self, lengths, mask, position_lambda, embedding_layer): + batch_size = mask.shape[0] + seq_len = mask.shape[1] + + positions = ( + torch.arange(start=seq_len - 1, end=-1, step=-1, device=DEVICE)[None] + .tile([batch_size, 1]) + .long() + ) # (batch_size, seq_len) + positions_mask = positions < lengths[:, None] # (batch_size, max_seq_len) + + positions = positions[positions_mask] # (all_batch_events) + # 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0 7 6 5 4 3 2 1 0 ... 
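+ # positions count down from the end of each sequence, so the most recent token always maps to position 0 before the lambdas above are applied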
+ + positions = position_lambda(positions) # (all_batch_events) + + # print(f"{positions.tolist()[:20]=}") + + assert (positions >= 0).all() and ( + positions < embedding_layer.num_embeddings + ).all() + + position_embeddings = embedding_layer( + positions + ) # (all_batch_events, embedding_dim) + + position_embeddings, _ = create_masked_tensor( + data=position_embeddings, lengths=lengths + ) # (batch_size, seq_len, embedding_dim) + + return position_embeddings diff --git a/modeling/optimizer/base.py b/modeling/optimizer/base.py index ede62be2..86c63537 100644 --- a/modeling/optimizer/base.py +++ b/modeling/optimizer/base.py @@ -1,8 +1,7 @@ import copy -from utils import MetaParent - import torch +from utils import MetaParent OPTIMIZERS = { 'sgd': torch.optim.SGD, diff --git a/modeling/rqvae_utils/__init__.py b/modeling/rqvae_utils/__init__.py new file mode 100644 index 00000000..2722515a --- /dev/null +++ b/modeling/rqvae_utils/__init__.py @@ -0,0 +1,4 @@ +from .collision_solver import CollisionSolver +from .trie import Trie +from .tree import Tree +from .simplified_tree import SimplifiedTree \ No newline at end of file diff --git a/modeling/rqvae_utils/collision_solver.py b/modeling/rqvae_utils/collision_solver.py new file mode 100644 index 00000000..f673c6dd --- /dev/null +++ b/modeling/rqvae_utils/collision_solver.py @@ -0,0 +1,179 @@ +from collections import defaultdict + +import torch + +from utils import DEVICE + + +class CollisionSolver: + def __init__(self, + emb_dim: int, + sem_id_len: int, + codebook_size: int, + device: torch.device=DEVICE): + """ + :param emb_dim: Dimensionality of the residual + :param codebook_size: Number of entries in a single codebook + :param sem_id_len: Length of a semantic_id (without the collision-resolving token) + :param device: Device to run on + """ + self._sem_ids_sparse_tensor: torch.Tensor = torch.empty((0, 0)) # sparse tensor grouping residuals by semantic_id + self.item_ids_sparse_tensor: torch.Tensor = torch.empty( + (0, 0)) # sparse tensor grouping real item ids by semantic_id + self.counts_dict: dict[int, int] = defaultdict(int) # dict storing the collision count per semantic_id + self.emb_dim: int = emb_dim # residual dimensionality + self.sem_id_len: int = sem_id_len # semantic_id length + self.codebook_size: int = codebook_size # number of entries in a single codebook + self.device: torch.device = device # device + + self.key: torch.Tensor = torch.tensor([self.codebook_size ** i for i in range(self.sem_id_len)], + dtype=torch.long, + device=self.device) # key mapping each semantic_id to a unique integer + + def create_query_candidates_dict(self, item_ids: torch.Tensor, semantic_ids: torch.Tensor, + residuals: torch.Tensor) -> None: + """ + Builds a sparse tensor whose entries are grouped by semantic id. + + :param item_ids: Real item ids (assumed to be greater than 0) (count,) + :param semantic_ids: Tensor of all semantic_ids produced by the RQ-VAE (without collision-resolving tokens) (count, sem_id_len) + :param residuals: Tensor of residuals for each semantic_id (count, emb_dim) + """ + residuals_count, residual_length = residuals.shape + semantic_ids_count, semantic_id_length = semantic_ids.shape + + assert residuals_count == semantic_ids_count + assert semantic_id_length == self.sem_id_len + assert residual_length == self.emb_dim + assert item_ids.shape == (residuals_count,) + + item_ids = item_ids.to(self.device) + residuals = residuals.to(self.device) + semantic_ids = semantic_ids.to(self.device) + + unique_id = (semantic_ids * self.key).sum(dim=1) # hashes +
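# e.g. with codebook_size=256 and sem_id_len=3, id [a, b, c] hashes to a+256*b+256**2*c +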
unique_ids, inverse_indices, counts = torch.unique(unique_id, return_inverse=True, return_counts=True) + sorted_indices = torch.argsort(inverse_indices) # indices sorted so that equal hashes are contiguous + + row_indices = inverse_indices[sorted_indices] # sorted hashes + + offsets = torch.cumsum(counts, dim=0) - counts + col_indices = torch.arange(semantic_ids_count, device=self.device) - offsets[ + row_indices] # indices 0..k within each group of equal hashes + + indices = torch.stack([ + unique_ids[row_indices], + col_indices + ], + dim=0) # indices for the sparse tensor: dim 1 is the hash, dim 2 enumerates 0..k over each hash's collisions + + max_residuals_count = int(counts.max().item()) # maximum number of collisions for a single sem_id + max_sid = int(self.codebook_size ** self.sem_id_len) # largest possible sem_id hash + + self._sem_ids_sparse_tensor = torch.sparse_coo_tensor(indices, residuals[sorted_indices], + size=(max_sid, max_residuals_count, self.emb_dim), + device=self.device) # (max_sid, max_residuals_count, emb_dim) + + self.counts_dict = defaultdict(int, zip(unique_ids.tolist(), counts.tolist())) # sid -> collision count + + self.item_ids_sparse_tensor = torch.sparse_coo_tensor(indices, item_ids[sorted_indices], + size=(max_sid, max_residuals_count), device=self.device, + dtype=torch.int32) # (max_sid, max_residuals_count) + + def get_residuals_by_semantic_id_batch(self, semantic_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + :param semantic_ids: batch of semantic ids (batch_size, sem_id_len) + + :return: + Tensor of candidate residual embeddings for the batch of semantic_ids, shape (batch_size, max_residuals_count, emb_dim), + and the validity mask for that tensor, shape (batch_size, max_residuals_count) + """ + assert semantic_ids.shape[1] == self.sem_id_len + + semantic_ids = semantic_ids.to(self.device) + unique_ids = (semantic_ids * self.key).sum(dim=1) + + candidates = torch.stack([self._sem_ids_sparse_tensor[key].to_dense() for key in unique_ids]) + counts = torch.tensor([self.counts_dict[key.item()] for key in unique_ids], device=self.device) + mask = torch.arange(candidates.shape[1], device=self.device).expand(len(unique_ids), -1) < counts.view(-1, 1) + + return candidates, mask + + def get_pred_scores(self, semantic_ids: torch.Tensor, pred_residuals: torch.Tensor) -> dict[str, torch.Tensor]: + """ + :param semantic_ids: [batch_size, sem_id_len] semantic ids (without the collision-resolving token) + :param pred_residuals: [batch_size, emb_dim] predicted residuals + + :return: Dictionary with keys: + - 'pred_scores_mask': [batch_size, max_collision_count] mask of valid score entries for the predicted residuals + - 'pred_scores': [batch_size, max_collision_count] dot-product scores over the candidates (invalid slots filled with -inf) + - 'pred_item_ids': [batch_size] real item ids selected by the predicted residuals + """ + assert semantic_ids.shape[1] == self.sem_id_len + assert pred_residuals.shape[1] == self.emb_dim + assert semantic_ids.shape[0] == pred_residuals.shape[0] + + semantic_ids = semantic_ids.to(self.device) + pred_residuals = pred_residuals.to(self.device) + + unique_ids = (semantic_ids * self.key).sum(dim=1) + + candidates, mask = self.get_residuals_by_semantic_id_batch(semantic_ids) + + pred_scores = torch.einsum('njk,nk->nj', candidates, pred_residuals).masked_fill(~mask, -torch.inf) + pred_indices = torch.argmax(pred_scores, dim=1) + pred_item_ids = torch.stack( +
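# map each argmax slot back to the real item id stored for that (semantic id, collision slot) pair +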
[self.item_ids_sparse_tensor[unique_ids[i]][pred_indices[i]] for i in range(semantic_ids.shape[0])]) + + return { + "pred_scores_mask": mask, + "pred_scores": pred_scores, + "pred_item_ids": pred_item_ids + } + + def get_true_dedup_tokens(self, semantic_ids: torch.Tensor, true_residuals: torch.Tensor) -> dict[ + str, torch.Tensor]: + """ + :param semantic_ids: [batch_size, sem_id_len] semantic ids (without the collision-resolving token) + :param true_residuals: [batch_size, emb_dim] true residuals + + :return: Dictionary with keys: + - 'true_dedup_tokens': [batch_size] collision-resolving tokens for the true residuals + """ + assert semantic_ids.shape[1] == self.sem_id_len + assert true_residuals.shape[1] == self.emb_dim + assert semantic_ids.shape[0] == true_residuals.shape[0] + + semantic_ids = semantic_ids.to(self.device) + true_residuals = true_residuals.to(self.device) + + candidates, _ = self.get_residuals_by_semantic_id_batch(semantic_ids) + + matches = torch.all(candidates == true_residuals[:, None, :], dim=2).int() + true_dedup_tokens = torch.argmax(matches, dim=1) + + assert matches.any(dim=1).all(), "Not every batch element has a matching residual" + + return { + "true_dedup_tokens": true_dedup_tokens + } + + def get_item_ids_batch(self, semantic_ids: torch.Tensor, dedup_tokens: torch.Tensor) -> torch.Tensor: + """ + :param semantic_ids: [batch_size, sem_id_len] semantic ids (without collision-resolving tokens) + :param dedup_tokens: [batch_size] collision-resolving tokens + + :return: item_ids : [batch_size] real item ids + """ + assert semantic_ids.shape[1] == self.sem_id_len + assert dedup_tokens.shape == (semantic_ids.shape[0],) + + semantic_ids = semantic_ids.to(self.device) + dedup_tokens = dedup_tokens.to(self.device) + + unique_ids = (semantic_ids * self.key).sum(dim=1) + + item_ids = torch.stack( + [self.item_ids_sparse_tensor[unique_ids[i]][dedup_tokens[i]] for i in range(semantic_ids.shape[0])]) + + return item_ids diff --git a/modeling/rqvae_utils/rqvae_data.py b/modeling/rqvae_utils/rqvae_data.py new file mode 100644 index 00000000..47692cae --- /dev/null +++ b/modeling/rqvae_utils/rqvae_data.py @@ -0,0 +1,96 @@ +import pandas as pd +import json +import gzip +from transformers import AutoTokenizer, AutoModelForSeq2SeqLM +import torch +import random + +from tqdm import tqdm + +tqdm.pandas() + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model_name = "google-t5/t5-small" + +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device) + + +def parse(path): + g = gzip.open(path, "rb") + for line in g: + yield eval(line) + + +def getDF(path): + i = 0 + df = {} + for d in parse(path): + df[i] = d + i += 1 + return pd.DataFrame.from_dict(df, orient="index") + + +def encode_text(text): + enc = tokenizer(text, return_tensors="pt", truncation=True).to(device) + + output = model.encoder( + input_ids=enc["input_ids"], + attention_mask=enc["attention_mask"], + return_dict=True, + ) + + embeddings = output.last_hidden_state.mean( + dim=1 + ).squeeze() # mean over all tokens (mb CLS?) + + return embeddings.cpu().detach() + + +def preprocess(row: pd.Series): + row = row.fillna("unknown") # empty? + # remove column description / title / cat? + return f"Description: {row['description']}. Title: {row['title']}.
Categories: {', '.join(row['categories'][0])}" + + +def get_data(cached=True): + if not cached: + df = getDF("../data/meta_Beauty.json.gz") + + file_name = "../data/reviews_Beauty_5.json" + + unique_items = set() + unique_users = set() + + with open(file_name, "r") as file: + for line in file: + review = json.loads(line.strip()) + unique_items.add(review["asin"]) + unique_users.add(review["reviewerID"]) + + df = df[df["asin"].isin(unique_items)] + + df["combined_text"] = df.apply(preprocess, axis=1) + + with torch.no_grad(): + df["embeddings"] = df["combined_text"].progress_apply(encode_text) + else: + df = torch.load("../data/Beauty/data_full.pt", weights_only=False) + + return df + + +def search_similar_items(items_with_tuples, clust2search, max_cnt=5): + random.shuffle(items_with_tuples) + cnt = 0 + similars = [] + for asin, item, clust_tuple in items_with_tuples: + if clust_tuple[: len(clust2search)] == clust2search: + similars.append((asin, item, clust_tuple)) + cnt += 1 + if cnt >= max_cnt: + return similars + return similars + + + diff --git a/modeling/rqvae_utils/rqvae_test.py b/modeling/rqvae_utils/rqvae_test.py new file mode 100644 index 00000000..611e7739 --- /dev/null +++ b/modeling/rqvae_utils/rqvae_test.py @@ -0,0 +1,51 @@ +import json + +import numpy as np +import torch + +from models import RqVaeModel +from utils import DEVICE + +def test(a, b): + cos_sim = torch.nn.functional.cosine_similarity(a, b, dim=0) + norm_a = torch.norm(a, p=2) + norm_b = torch.norm(b, p=2) + l2_dist = torch.norm(a - b, p=2) / (norm_a + norm_b + 1e-8) + return cos_sim, l2_dist + +if __name__ == "__main__": + config = json.load(open("../configs/train/tiger_train_config.json")) + config = config["model"] + rqvae_config = json.load(open(config["rqvae_train_config_path"])) + rqvae_config["model"]["should_init_codebooks"] = False + rqvae_model = RqVaeModel.create_from_config(rqvae_config["model"]).to(DEVICE) + rqvae_model.load_state_dict( + torch.load(config["rqvae_checkpoint_path"], weights_only=True) + ) + df = torch.load(config["embs_extractor_path"], weights_only=False) + embeddings_array = np.stack(df["embeddings"].values) + tensor_embeddings = torch.tensor(embeddings_array, dtype=torch.float32, device=DEVICE) + inputs = {'embeddings': tensor_embeddings} + + rqvae_model.eval() + sem_ids, residuals = rqvae_model.forward(inputs) + scores = residuals.detach() + print(torch.norm(residuals, p=2, dim=1).median()) + for (i, codebook) in enumerate(rqvae_model.codebooks): + scores += codebook[sem_ids[:, i]].detach() + decoder_output = rqvae_model.decoder(scores.detach()).detach() + + a = tensor_embeddings[0] + b = decoder_output[0] + cos_sim, l2_dist = test(a, b) + print("cosine similarity", cos_sim) + print("normalized euclidean distance", l2_dist) + + cos_sim = torch.nn.functional.cosine_similarity(tensor_embeddings, decoder_output, dim=1) + print("cosine similarity", cos_sim.mean(), cos_sim.min(), cos_sim.max()) + + norm_a = torch.norm(tensor_embeddings, p=2, dim = 1) + norm_b = torch.norm(decoder_output, p=2, dim = 1) + l2_dist = torch.norm(decoder_output - tensor_embeddings, p=2, dim = 1) / (norm_a + norm_b + 1e-8) + print("normalized euclidean distance", l2_dist.median(), l2_dist.min(), l2_dist.max()) + diff --git a/modeling/rqvae_utils/simplified_tree.py b/modeling/rqvae_utils/simplified_tree.py new file mode 100644 index 00000000..5a09dcb4 --- /dev/null +++ b/modeling/rqvae_utils/simplified_tree.py @@ -0,0 +1,101 @@ +import torch + +from utils import DEVICE + + +class SimplifiedTree: + def __init__(self,
embedding_table: torch.Tensor, device: torch.device = DEVICE): + """ + :param embedding_table: trained codebook embeddings + :param device: device to run on + """ + self.device: torch.device = device + self.embedding_table: torch.Tensor = embedding_table # (semantic_id_len, codebook_size, emb_dim) + self.sem_id_len, self.codebook_size, self.emb_dim = self.embedding_table.shape + self.sem_ids_count: int = 0 + self.full_embeddings: torch.Tensor = torch.empty((0, 0)) + self.item_ids: torch.Tensor = torch.empty((0, 0)) + + def build_tree_structure(self, semantic_ids: torch.Tensor, residuals: torch.Tensor, item_ids: torch.Tensor, + sum_with_residuals: bool = True) -> None: + """ + :param sum_with_residuals: flag controlling whether residuals are taken into account when selecting candidates + :param semantic_ids: (sem_ids_count, sem_id_len) + :param residuals: (sem_ids_count, emb_dim) + :param item_ids: (sem_ids_count,) + """ + self.sem_ids_count = semantic_ids.shape[0] + assert residuals.shape == (self.sem_ids_count, self.emb_dim) + assert semantic_ids.shape == (self.sem_ids_count, self.sem_id_len) + assert item_ids.shape == (self.sem_ids_count,) + + semantic_ids = semantic_ids.to(self.device) + residuals = residuals.to(self.device).float() if sum_with_residuals else torch.zeros_like(residuals, + device=self.device, + dtype=torch.float) + self.full_embeddings = self.calculate_full(semantic_ids).float() + residuals + self.item_ids = item_ids + + def calculate_full(self, sem_ids: torch.Tensor) -> torch.Tensor: + """ + :param sem_ids: set of sem ids (count, sem_id_len) + :return: embedding for each sem_id in the set (count, emb_dim) + """ + assert sem_ids.shape[1] == self.sem_id_len + sem_ids = sem_ids.to(self.device) + + expanded_emb_table = (self.embedding_table.unsqueeze(0) + .expand(sem_ids.shape[0], -1, -1, -1)) # (count, sem_id_len, codebook_size, emb_dim) + + index = (sem_ids.unsqueeze(-1) + .expand(-1, -1, self.emb_dim) + .unsqueeze(2)) # (count, sem_id_len, 1, emb_dim) + + return torch.gather(input=expanded_emb_table, index=index, dim=2).sum(1).squeeze(1) # (count, emb_dim) + + def query(self, request_sem_ids: torch.Tensor, items_to_query: int) -> torch.Tensor: + """ + :param request_sem_ids: batch of sem ids (batch_size, sem_id_len) + :param items_to_query: number of nearest items to retrieve (int) + :return: tensor with the ids of the k nearest items among all semantic_ids for each sem_id in the batch (batch_size, k) + """ + assert request_sem_ids.shape[1] == self.sem_id_len + assert 0 < items_to_query <= self.sem_ids_count + + request_sem_ids = request_sem_ids.to(self.device) + request_embeddings = self.calculate_full(request_sem_ids) # (batch_size, emb_dim) + + request_embeddings = (request_embeddings.unsqueeze(1) + .expand(-1, self.sem_ids_count, -1)) # (batch_size, sem_ids_count, emb_dim) + + diff_norm = torch.norm(self.full_embeddings - request_embeddings, p=2, dim=2) # (batch_size, sem_ids_count) + + indices = torch.argsort(diff_norm, descending=False, dim=1)[:, :items_to_query] # (batch_size, k) + return self.item_ids[indices] + + def _query(self, request_sem_ids: torch.Tensor, k: int) -> torch.Tensor: + """ + Alternative to `query`, an attempt to speed it up + :param request_sem_ids: batch of sem ids (batch_size, sem_id_len) + :param k: number of nearest items to retrieve (int) + :return: tensor with the ids of the k nearest items among all semantic_ids for each sem_id in the batch (batch_size, k) + """ + assert request_sem_ids.shape[1] == self.sem_id_len + assert 0 < k <= self.sem_ids_count + request_sem_ids
= request_sem_ids.to(self.device) + + index = (request_sem_ids.unsqueeze(-1) + .expand(-1, -1, self.emb_dim) + .unsqueeze(2)) # (batch_size, sem_id_len, 1, emb_dim) + + request_embeddings = torch.gather( + input=self.embedding_table.unsqueeze(0).expand(request_sem_ids.shape[0], -1, -1, -1), + dim=2, + index=index + ).sum(1) # (batch_size, emb_dim) + + diff_norm = torch.cdist(self.full_embeddings, request_embeddings.unsqueeze(1), p=2).squeeze( + 1) # (batch_size, sem_ids_count) + + _, indices = torch.topk(diff_norm, k=k, dim=1, largest=False) # (batch_size, k) + return self.item_ids[indices.squeeze(-1)] diff --git a/modeling/rqvae_utils/tree.py b/modeling/rqvae_utils/tree.py new file mode 100644 index 00000000..b09cd049 --- /dev/null +++ b/modeling/rqvae_utils/tree.py @@ -0,0 +1,209 @@ +import numpy as np +import torch + +from utils import DEVICE + + +class Tree: + def __init__(self, embedding_table: torch.Tensor, device: torch.device = DEVICE): + """ + :param embedding_table: trained codebook embeddings + :param device: device to run on + """ + self.device: torch.device = device + self.embedding_table: torch.Tensor = embedding_table # (semantic_id_len, codebook_size, emb_dim) + self.sem_id_len, self.codebook_size, self.emb_dim = self.embedding_table.shape + self.key: torch.Tensor = torch.empty((0, 0)) + self.A: torch.Tensor = torch.empty((0, 0)) # will be (max_sem_id,) + self.sem_ids_count: int = -1 + self.sem_ids_embs: torch.Tensor = torch.empty((0, 0)) + self.sids: torch.Tensor = torch.empty((0, 0)) # will be (sem_ids_count,) + self.item_ids: torch.Tensor = torch.empty((0, 0)) + + def build_tree_structure(self, semantic_ids: torch.Tensor, residuals: torch.Tensor, item_ids: torch.Tensor): + """ + :param semantic_ids: (sem_ids_count, sem_id_len) + :param residuals: (sem_ids_count, emb_dim) + :param item_ids: (sem_ids_count,) + """ + self.sem_ids_count = semantic_ids.shape[0] + + assert semantic_ids.shape[0] == residuals.shape[0] + assert semantic_ids.shape[1] == self.sem_id_len + assert residuals.shape[1] == self.emb_dim + assert item_ids.shape == (self.sem_ids_count,) + + self.item_ids = item_ids + self.key = torch.tensor([self.codebook_size ** i for i in range(self.sem_id_len - 1, -1, -1)], + dtype=torch.long, device=self.device) + self.sids = self.get_sids(semantic_ids.float()) # (sem_ids_count,) + self.sem_ids_embs = self.calculate_full(semantic_ids, residuals) + + result = torch.full(size=[self.codebook_size ** self.sem_id_len], fill_value=0, dtype=torch.int64, + device=self.device) + temp_unique_id = self.sids * self.codebook_size + temp_sem_ids = torch.concat([semantic_ids, torch.zeros(self.sem_ids_count, device=self.device).unsqueeze(1)], + dim=-1) + + for i in range(0, self.sem_id_len + 1): + temp_unique_id = temp_unique_id - (self.codebook_size ** i) * temp_sem_ids[:, self.sem_id_len - i] + temp_unique_ids, temp_inverse_indices = torch.unique(temp_unique_id, return_inverse=True) + temp_counts = torch.bincount(temp_inverse_indices) + truncated_ids = torch.floor_divide(input=temp_unique_id, other=(self.codebook_size ** (i + 1))).long() + result[truncated_ids] = temp_counts[temp_inverse_indices] + + self.A = result + + def get_counts(self, sem_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + :param sem_ids: (batch_size, sem_id_len) + :return: prefix hashes of every length for sem_ids, and the number of stored sem_ids under each tree depth + """ + assert sem_ids.shape[1] == self.sem_id_len + + offsets = torch.arange(self.sem_id_len + 1, device=self.device) + i = torch.arange(self.sem_id_len, device=self.device) +
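+ # row o of mask_sem keeps the first sem_id_len-o positions, i.e. it hashes the length-(sem_id_len-o) prefix of each semantic id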
+ mask_sem = (i < (self.sem_id_len - offsets.unsqueeze(1))).long() # (sem_id_len + 1, sem_id_len) + divs = torch.pow(self.codebook_size, offsets) # (sem_id_len + 1,) + + C = (sem_ids.unsqueeze(1) * mask_sem.unsqueeze(0) * self.key.unsqueeze(0).unsqueeze(1)).sum(dim=-1) + B = C // divs.unsqueeze(0) + + return C, self.A[B] # (batch_size, sem_id_len + 1), (batch_size, sem_id_len + 1) + + def get_sids(self, sem_ids: torch.Tensor) -> torch.Tensor: + """ + :param sem_ids: (sem_ids_count, sem_id_len) + :return: hash keys of sem_ids (sem_ids_count,) + """ + assert sem_ids.shape[1] == self.sem_id_len + return torch.einsum('nc,c->n', sem_ids, self.key.float()) # (sem_ids_count,) + + def calc_ol(self, batch_ids: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + :param batch_ids: (batch_size, sem_id_len) + :param k: int + :return: tensor of depths to climb up to (batch_size,), and a mask over stored sem_ids at that depth (batch_size, sem_ids_count) + """ + assert batch_ids.shape[1] == self.sem_id_len + assert k < self.sem_ids_count # guarantees the result tensor is non-empty + + c, a = self.get_counts(batch_ids) + ol = torch.argmax((a > k).long(), dim=-1) # (bs,) + gather_ol = torch.gather(c, dim=1, index=ol.unsqueeze(1)).squeeze() # (bs,) + + mask_ol = (gather_ol.unsqueeze(-1) <= self.sids) & ( + self.sids < (gather_ol + torch.pow(self.codebook_size, ol)).unsqueeze(-1)) + return ol, mask_ol # (bs,) (bs, sem_ids_count) + + def calc_il(self, batch_ids: torch.Tensor, k: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + :param batch_ids: (batch_size, sem_id_len) + :param k: int + :return: tensor of depths to climb up to (batch_size,), and a mask over stored sem_ids at that depth (batch_size, sem_ids_count) + """ + assert batch_ids.shape[1] == self.sem_id_len + assert k < self.sem_ids_count # guarantees the result tensor is non-empty + + batch_dim = batch_ids.shape[0] + c, a = self.get_counts(batch_ids) + extended_c = torch.concat([torch.tensor(float("inf"), device=self.device).expand(batch_dim, 1), c], dim=1) + + il = torch.argmax((a > k).long(), dim=-1) - 1 # (bs,) + gather_il = torch.gather(extended_c, dim=1, index=(il + 1).unsqueeze(1)).squeeze() # (bs,) + + mask_il = (gather_il.unsqueeze(-1) <= self.sids) & ( + self.sids < (gather_il + torch.pow(self.codebook_size, il)).unsqueeze(-1)) + return il, mask_il # (bs,) (bs, sem_ids_count) + + def get_repeated_sids(self, k: int) -> torch.Tensor: + return self.sids.repeat(k, 1) # (k, sem_ids_count) + + def get_request_embeddings(self, decomposed_embeddings: torch.Tensor, levels: torch.Tensor) -> torch.Tensor: + """ + :param decomposed_embeddings: per-level embedding decomposition of sem_ids (count, sem_id_len + 1, emb_dim) + :param levels: how many embeddings to sum for each sem_id (count,) + :return: sem_id embeddings at the requested depths (count, emb_dim) + """ + assert decomposed_embeddings.shape[1:] == (self.sem_id_len + 1, self.emb_dim) + assert levels.shape == (decomposed_embeddings.shape[0],) + + mask = torch.arange(1, self.sem_id_len + 2, device=self.device) >= torch.arange(self.sem_id_len + 2, 0, -1, + device=self.device).unsqueeze(1) + return torch.sum(decomposed_embeddings * mask[levels + 1].unsqueeze(-1), dim=1) # (bs, emb_dim) + + def calculate_full(self, sem_ids: torch.Tensor, residuals: torch.Tensor) -> torch.Tensor: + """ + :param sem_ids: sem_ids (count, sem_id_len) + :param residuals: residuals for each sem_id (count, emb_dim) + :return: per-level embedding decomposition for each item (count, sem_id_len + 1, emb_dim) + """ + assert sem_ids.shape[1] == self.sem_id_len + assert residuals.shape[1] ==
self.emb_dim + assert residuals.shape[0] == sem_ids.shape[0] + + count = residuals.shape[0] + index = sem_ids.view(count, -1, 1, 1).expand(-1, -1, -1, self.emb_dim) + embs = torch.gather(input=self.embedding_table.unsqueeze(0).expand(count, -1, -1, -1), dim=2, + index=index) # expand is free memory-wise (no copy) + decomposed_embs = torch.concat([embs.squeeze(2), residuals.unsqueeze(1)], dim=1) # (count, sem_id_len + 1, emb_dim) + + assert decomposed_embs.shape == (sem_ids.shape[0], self.sem_id_len + 1, self.emb_dim) + return decomposed_embs + + def calculate_level_embeddings(self, decomposed_embeddings: torch.Tensor, levels: torch.Tensor) -> torch.Tensor: + """ + :param decomposed_embeddings: per-level embedding decomposition of sem_ids (count, sem_id_len + 1, emb_dim) + :param levels: how many embeddings to sum for each request (batch_size,) + :return: embeddings of all stored sem_ids at the requested depths (batch_size, sem_ids_count, emb_dim) + """ + assert decomposed_embeddings.shape == (self.sem_ids_count, self.sem_id_len + 1, self.emb_dim) + + mask = (torch.arange(1, self.sem_id_len + 2, device=self.device) >= + torch.arange(self.sem_id_len + 2, 0, -1, device=self.device).unsqueeze(1)).float() + sids_mask = mask[levels + 1].unsqueeze(-1) # (batch_size, sem_id_len + 1, 1) + return torch.einsum('nld,bld->bnd', decomposed_embeddings, sids_mask) # (batch_size, sem_ids_count, emb_dim) + + def mask_result(self, result: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + return torch.where(mask, result, torch.tensor(float('-inf'), device=self.device)) + + def query(self, request_sem_ids: torch.Tensor, request_residuals: torch.Tensor, + items_to_query: int) -> torch.Tensor: + """ + :param request_sem_ids: batch of sem_ids (batch_size, sem_id_len) + :param request_residuals: batch of residuals (batch_size, emb_dim) + :param items_to_query: number of nearest items to retrieve, int + :return: tensor of indices of the nearest k items among all stored semantic_ids for each sem_id in the batch (batch_size, k) + """ + assert request_sem_ids.shape[0] == request_residuals.shape[0] + assert request_sem_ids.shape[1] == self.sem_id_len + assert request_residuals.shape[1] == self.emb_dim + assert 0 <= items_to_query < self.sem_ids_count + + ol, ol_mask = self.calc_ol(request_sem_ids, items_to_query) + il, il_mask = self.calc_il(request_sem_ids, items_to_query) + + il_mask = il_mask.detach().cpu() + ol_mask = ol_mask.detach().cpu() + + ol_mask = ol_mask & ~il_mask + + request_embs = self.calculate_full(request_sem_ids, request_residuals) + + ol_sids_embeddings = self.calculate_level_embeddings(self.sem_ids_embs, ol) + il_sids_embeddings = self.calculate_level_embeddings(self.sem_ids_embs, il) + + ol_request_embeddings = self.get_request_embeddings(request_embs, ol) + il_request_embeddings = self.get_request_embeddings(request_embs, il) + + ol_scores = torch.matmul(ol_sids_embeddings, ol_request_embeddings.unsqueeze(-1)).squeeze(-1).detach().cpu() + + il_scores = torch.matmul(il_sids_embeddings, il_request_embeddings.unsqueeze(-1)).squeeze(-1).detach().cpu() + + ids = np.lexsort(keys=(-torch.cat([il_scores, ol_scores], dim=1), + ~torch.cat([torch.ones_like(il_mask), torch.zeros_like(ol_mask)], dim=1), + ~torch.cat([il_mask, ol_mask], dim=1))) + + ids = (ids % self.sem_ids_count)[:, :self.sem_ids_count][:, :items_to_query] # (batch_size, k) + return self.item_ids[ids] diff --git a/modeling/rqvae_utils/tree_comparing.py b/modeling/rqvae_utils/tree_comparing.py new file mode 100644 index 00000000..3f1d3f38 --- /dev/null +++ 
b/modeling/rqvae_utils/tree_comparing.py @@ -0,0 +1,128 @@ +import json +import os +import time + +import psutil +import torch + +from models.rqvae import RqVaeModel +from rqvae_utils import Trie, SimplifiedTree, Tree +from utils import DEVICE + + +def memory_stats(k): + process = psutil.Process(os.getpid()) + memory_usage = process.memory_info().rss / 1024 ** 2 + print(f"{k}. Memory usage: {memory_usage:.2f} MB") + + +def calc_sid(sid, codebook_size): + res = sid[-1] + for i in range(1, sid.shape[0]): + res += sid[-i - 1] * (codebook_size ** i) + return res + + +def stats(query_sem_id, codebook_size, sids, item_ids): + for sem_id, ids in zip(query_sem_id.tolist(), item_ids.tolist()): + print(calc_sid(torch.tensor(sem_id), codebook_size)) + print(sids[torch.tensor(ids)][:10]) + + +if __name__ == "__main__": + embedding_dim = 64 # Embedding size + config = json.load(open("../configs/train/tiger_train_config.json")) + config = config["model"] + rqvae_config = json.load(open(config["rqvae_train_config_path"])) + rqvae_config["model"]["should_init_codebooks"] = False + rqvae_model = RqVaeModel.create_from_config(rqvae_config["model"]).to(DEVICE) + rqvae_model.load_state_dict( + torch.load(config["rqvae_checkpoint_path"], weights_only=True) + ) + rqvae_model.eval() + + emb_table = torch.stack( + [cb for cb in rqvae_model.codebooks] + ).to(DEVICE) + + trie = Trie(rqvae_model) + tree = Tree(emb_table, DEVICE) # Tree expects the stacked codebook table, not the model + simplified_tree = SimplifiedTree(rqvae_model, DEVICE) + simplified_tree_wr = SimplifiedTree(rqvae_model, DEVICE) + alphabet_size = 10 + + N = 12101 + K = 3 + + semantic_ids = torch.randint(0, alphabet_size, (N, K), dtype=torch.int64).to(DEVICE) + residuals = torch.randn(N, embedding_dim).to(DEVICE) + item_ids = torch.arange(5, N + 5).to(DEVICE) + print(residuals[0]) + + now = time.time() + trie.build_tree_structure(semantic_ids, residuals, item_ids) + print(f"Time for trie init: {(time.time() - now) * 1000:.2f} ms") + + now = time.time() + tree.build_tree_structure(semantic_ids, residuals, item_ids) + print(f"Time for tree init: {(time.time() - now) * 1000:.2f} ms") + + now = time.time() + simplified_tree.build_tree_structure(semantic_ids, residuals, item_ids) + print(f"Time for simplified tree init: {(time.time() - now) * 1000:.2f} ms") + + now = time.time() + simplified_tree_wr.build_tree_structure(semantic_ids, residuals, item_ids, False) + print(f"Time for simplified tree without residuals init: {(time.time() - now) * 1000:.2f} ms") + + full_embeddings = tree.calculate_full(semantic_ids, residuals).sum(1) + print(torch.all(full_embeddings == simplified_tree.full_embeddings)) + + items_to_query = 20 + batch_size = 256 + q_semantic_ids = torch.randint(0, alphabet_size, (batch_size, K), dtype=torch.int64, device=DEVICE) + q_residuals = torch.randn(batch_size, embedding_dim).to(DEVICE) + + total_time = 0 + n_exps = 1 + + memory_stats(1) + for i in range(n_exps): + now = time.time() + item_ids = trie.query(q_semantic_ids, q_residuals, items_to_query) + total_time += time.time() - now + stats(q_semantic_ids[:1], 256, tree.sids, item_ids[:1]) + + print(f"Time per query: {total_time / n_exps * 1000:.2f} ms") + + memory_stats(2) + total_time = 0 # reset between blocks so each one reports its own timing + + for i in range(n_exps): + now = time.time() + simplified_tree_ids = simplified_tree.query(q_semantic_ids, items_to_query) + total_time += time.time() - now + stats(q_semantic_ids[:1], 256, tree.sids, simplified_tree_ids[:1]) + + print(f"Time per query: {total_time / n_exps * 1000:.2f} ms") + + memory_stats(3) + total_time = 0 + + for i in range(n_exps): + now = 
time.time() + simplified_tree_ids = simplified_tree_wr.query(q_semantic_ids, items_to_query) + total_time += time.time() - now + stats(q_semantic_ids[:1], 256, tree.sids, simplified_tree_ids[:1]) + + print(f"Time per query: {total_time / n_exps * 1000:.2f} ms") + + memory_stats(4) + total_time = 0 + + for i in range(n_exps): + now = time.time() + tree_ids = tree.query(q_semantic_ids, q_residuals, items_to_query) + total_time += time.time() - now + stats(q_semantic_ids[:1], 256, tree.sids, tree_ids[:1]) + + print(f"Time per query: {total_time / n_exps * 1000:.2f} ms") + + memory_stats(5) diff --git a/modeling/rqvae_utils/trie.py b/modeling/rqvae_utils/trie.py new file mode 100644 index 00000000..b7971e86 --- /dev/null +++ b/modeling/rqvae_utils/trie.py @@ -0,0 +1,377 @@ +import json +import time + +import torch +from models.rqvae import RqVaeModel +from utils import DEVICE + + +class Trie: + def __init__(self, rqvae_model: RqVaeModel): + self.rqvae_model = rqvae_model + self.keys = None + self.prefix_counts = None + self.residuals_per_level = None + self.raw_item_ids = None + self.K = len(self.rqvae_model.codebook_sizes) + self.total_items = None + self.embedding_table = torch.stack( + [cb for cb in self.rqvae_model.codebooks] + ) # K x codebook_size x embedding_dim + + def unique_with_index(self, x, dim=None): + """Unique elements of x and indices of those unique elements + https://github.com/pytorch/pytorch/issues/36748#issuecomment-619514810 + + e.g. + + unique(tensor([ + [1, 2, 3], + [1, 2, 4], + [1, 2, 3], + [1, 2, 5] + ]), dim=0) + => (tensor([[1, 2, 3], + [1, 2, 4], + [1, 2, 5]]), + tensor([0, 1, 3])) + """ + unique, inverse = torch.unique(x, sorted=True, return_inverse=True, dim=dim) + perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device) + inverse, perm = inverse.flip([0]), perm.flip([0]) + return unique, inverse.new_empty(unique.size(0)).scatter_(0, inverse, perm) + + def compute_keys(self, semantic_ids: torch.Tensor): + # assumes uniform codebook sizes across levels + exponents = torch.arange(self.K - 1, -1, -1, device=DEVICE).float() + base = self.rqvae_model.codebook_sizes[0] ** exponents + uniq_ids = semantic_ids.float() @ base + return uniq_ids.int() + + def pad_semantic_ids(self, semantic_ids: torch.Tensor): + return torch.cat( + [ + semantic_ids, + torch.zeros( + semantic_ids.shape[0], + self.K - semantic_ids.shape[1], + dtype=semantic_ids.dtype, + device=semantic_ids.device, + ), + ], + dim=1, + ) + + def build_tree_structure( + self, + semantic_ids: torch.Tensor, + residuals: torch.Tensor, + raw_item_ids: torch.Tensor, + ): + """ + Order of semantic ids, residuals, raw_item_ids must be the same (corresponding to same item) + """ + bs = semantic_ids.shape[0] + + prefix_counts = torch.zeros(bs, self.K + 1, dtype=torch.int64) # bs x K+1 + prefix_counts[:, 0] = bs + + for i in range(self.K): + truncated_semantic_ids = semantic_ids[:, : i + 1] + padded_semantic_ids = self.pad_semantic_ids(truncated_semantic_ids) + prefix_keys = self.compute_keys( + padded_semantic_ids + ) # bs, semantic_ids order + unique_prefixes, inverse_indices_prefix_counts, prefix_counts_at_level = ( + torch.unique(prefix_keys, return_inverse=True, return_counts=True) + ) # [1 2 3 3 2] -> [1 2 3] [0 1 2 2 1] [1 2 2] + current_level_same = prefix_counts_at_level[ + inverse_indices_prefix_counts + ] # [1 2 2 2 2] + prefix_counts[:, i + 1] = current_level_same + + residuals_per_level = self.get_residuals_per_level( + semantic_ids, residuals + ) # total_items x K + 1 x embedding_dim + + keys = self.compute_keys(semantic_ids) # bs, could be 
collisions + + self.keys = keys + self.prefix_counts = prefix_counts + self.residuals_per_level = residuals_per_level + self.raw_item_ids = raw_item_ids + self.total_items = len(keys) + + def get_residuals_per_level( + self, + semantic_ids: torch.Tensor, + residuals: torch.Tensor, + ): + bs = semantic_ids.shape[0] + embedding_dim = residuals.shape[1] + residuals_per_level = torch.zeros( + bs, self.K + 1, embedding_dim, device=self.embedding_table.device + ) # bs x K + 1 x embedding_dim + + # TODOPK think if reverse is needed here + # i = 3, 2, 1, 0 + for i in range(self.K - 1, -1, -1): + indices_at_level = semantic_ids[:, i] # bs + embeddings_at_level = self.embedding_table[ + i, indices_at_level + ] # bs x embedding_dim + # 1 2 3 4 + residuals_per_level[:, self.K - i, :] = ( + embeddings_at_level + residuals_per_level[:, self.K - i - 1, :] + ) # [0 first_cumul_emb, second, ..., full_emb] + + # TODOPK check that residuals_per_level equal at last layer to full embedding of semantic id + + residuals_per_level[:, 0, :] = residuals + + return residuals_per_level # bs x K + 1 x embedding_dim + + def get_mask_by_prefix(self, prefixes: torch.Tensor, taken_lens: torch.Tensor): + bs = prefixes.shape[0] + padded_prefix = self.pad_semantic_ids(prefixes) + lower_key = self.compute_keys(padded_prefix) # bs + upper_key = lower_key + self.rqvae_model.codebook_sizes[0] ** ( + self.K - taken_lens + ) # bs + + # self.K = 4, prefix_len = 3 => 256 ^ 3 + 256 ^ 2 + 256 ^ 1 + 256 ^ 0 + # need to add 256 ^ 1 to get exclusive upper bound + # self.K = 4, prefix_len = 2 => 256 ^ 3 + 256 ^ 2 + 256 ^ 1 + 256 ^ 0 + # need to add 256 ^ 2 to get exclusive upper bound + # self.K = 4, prefix_len = 1 => 256 ^ 3 + 256 ^ 2 + 256 ^ 1 + 256 ^ 0 + # need to add 256 ^ 3 to get exclusive upper bound + # self.keys.shape = bs, lower_key.shape = bs, upper_key.shape = bs + + assert lower_key.shape[0] == upper_key.shape[0] == bs + assert self.keys.shape[0] == self.total_items + + mask = ( + ( + self.keys.unsqueeze(0) >= lower_key.unsqueeze(1) + ) # including prefix [1, 2, 0, 0] + & ( + self.keys.unsqueeze(0) <= upper_key.unsqueeze(1) + ) # excluding [1, 3, 0, 0], last [1, 2, 256, 256] + ) + + return mask + + def process_prefixes(self, prefixes: torch.Tensor): + bs, prefix_len = prefixes.shape + taken_len = torch.full((bs,), prefix_len, device=DEVICE) + mask = self.get_mask_by_prefix(prefixes, taken_len) + # self.keys.unsqueeze(0) = 1 x bs + # lower_key.unsqueeze(1), upper_key.unsqueeze(1) = bs x 1 + num_items_in_range = (mask).sum(dim=1) + return num_items_in_range # bs + + def get_outer_inner_levels(self, semantic_ids: torch.Tensor, items_to_query: int): + bs, K = semantic_ids.shape + num_items = torch.stack( + [self.process_prefixes(semantic_ids[:, : i + 1]) for i in range(K)], dim=1 + ) + num_items: torch.Tensor = torch.cat( + [ + torch.full( + (bs, 1), + self.total_items, + device=DEVICE, + ), + num_items, + ], + dim=1, + ) + + # first idx from end where it > items_to_query + + forward_mask = (num_items > items_to_query).int() # bs x K + 1 + backward_mask = forward_mask.flip(1) # bs x K + 1 + outer_level = K - torch.argmax(backward_mask, dim=1) # bs + inner_level = outer_level + 1 # bs + + # ol & il - how long prefix take => get (> items_to_query & <= items_to_query) items + + assert (outer_level <= K).all() + + return num_items, outer_level, inner_level # bs x K + 1, bs, bs + + def get_scores(self, item_indices, idx, query_residuals_per_level): + bs = idx.shape[0] # batch_size + + # stored[n, i, :] = 
self.residuals_per_level[item_indices[n,i], idx_expanded[n,i], :] + stored = self.residuals_per_level[item_indices[None, :], idx[:, None], :] + + # Gather the corresponding query vectors for each row: + # query[n, :] = query_residuals_per_level[n, idx[n], :] + query = query_residuals_per_level[ + torch.arange(bs, device=item_indices.device), idx, : + ] # Shape [batch_size, D] + + # Dot products => shape [batch_size, total_items] + scores = torch.einsum("bnd,bd->bn", stored, query) + + return scores + + def get_closest_vectorized( + self, + outer_masks, # shape: [batch_size, total_items] (boolean) + inner_masks, # shape: [batch_size, total_items] (boolean) + outer_levels, # shape: [batch_size] + inner_levels, # shape: [batch_size] + query_residuals_per_level, # shape: [batch_size, K+1, embedding_dim] + items_to_query, + ): + device = outer_masks.device + bs, total_items = outer_masks.shape + + item_indices = torch.arange(total_items, device=device) + + guaranteed_scores = self.get_scores( + item_indices, + -(inner_levels + 1), + query_residuals_per_level, + ) + guaranteed_scores = torch.where( + inner_masks, guaranteed_scores, torch.tensor(float("-inf"), device=device) + ) # [batch_size, total_items] + + left_scores = self.get_scores( + item_indices, + -outer_levels, + query_residuals_per_level, + ) + left_masks = outer_masks & ~inner_masks + left_scores = torch.where( + left_masks, left_scores, torch.tensor(float("-inf"), device=device) + ) # [batch_size, total_items] + + _, guaranteed_indices = torch.topk( + guaranteed_scores, items_to_query, dim=1 + ) # [batch_size, items_to_query] + _, left_indices = torch.topk( + left_scores, items_to_query, dim=1 + ) # [batch_size, items_to_query] + + indices = torch.cat( + [guaranteed_indices, left_indices], dim=1 + ) # [batch_size, 2 * items_to_query] + + top_ids = self.raw_item_ids[indices][ + :, :items_to_query + ] # [batch_size, items_to_query] + + return top_ids + + def query( + self, semantic_ids: torch.Tensor, residuals: torch.Tensor, items_to_query: int + ): + bs, K = semantic_ids.shape + + assert K == self.K, "Semantic ids must have same number of levels as the trie" + + num_items, outer_levels, inner_levels = self.get_outer_inner_levels( + semantic_ids, items_to_query + ) # bs x K + 1, bs, bs + + # print(num_items.shape, outer_levels.shape, inner_levels.shape) + # print(num_items, outer_levels, inner_levels) + + taken_outer_prefixes = semantic_ids * ( + torch.arange(K, device=DEVICE).expand(bs, K) < outer_levels.unsqueeze(1) + ) + taken_inner_prefixes = semantic_ids * ( + torch.arange(K, device=DEVICE).expand(bs, K) < inner_levels.unsqueeze(1) + ) + + outer_masks = self.get_mask_by_prefix( + taken_outer_prefixes, outer_levels + ) # bs, total_items + inner_masks = self.get_mask_by_prefix( + taken_inner_prefixes, inner_levels + ) # bs, total_items + + # print(inner_masks.shape, outer_masks.shape) + # print(inner_masks, outer_masks) + + inner_levels_max_mask = inner_levels == self.K + 1 + inner_levels[inner_levels_max_mask] = self.K + inner_masks[inner_levels_max_mask] = outer_masks[inner_levels_max_mask] + + assert ( + num_items[torch.arange(bs), outer_levels] == outer_masks.sum(dim=1) + ).all() + assert ( + num_items[torch.arange(bs), inner_levels] == inner_masks.sum(dim=1) + ).all() + + assert (outer_masks.sum(dim=1) > items_to_query).all() + # assert (inner_masks.sum(dim=1) <= items_to_query).all() # can be false if collisions + + assert (inner_masks <= outer_masks).all() + + query_residuals_per_level = self.get_residuals_per_level( + 
semantic_ids, residuals + ) + + raw_item_ids = self.get_closest_vectorized( + outer_masks, + inner_masks, + outer_levels, + inner_levels, + query_residuals_per_level, + items_to_query, + ) + + return raw_item_ids + + +if __name__ == "__main__": + embedding_dim = 512 # Embedding size + config = json.load(open("../configs/train/tiger_train_config.json")) + config = config["model"] + rqvae_config = json.load(open(config["rqvae_train_config_path"])) + rqvae_config["model"]["should_init_codebooks"] = False + rqvae_model = RqVaeModel.create_from_config(rqvae_config["model"]) + rqvae_model.load_state_dict( + torch.load(config["rqvae_checkpoint_path"], weights_only=True) + ) + rqvae_model.eval() + + trie = Trie(rqvae_model) + alphabet_size = 6 + + N = 12101 + K = 3 + # build a 40 x K tensor of two distinct rows repeated 20 times each: + # ([0, 1, 2] x 20, then [4, 5, 6] x 20) + a = torch.arange(K).repeat(20, 1) + b = torch.arange(K + 1, K + K + 1).repeat(20, 1) + semantic_ids = torch.cat([a, b], dim=0) + residuals = torch.randn(semantic_ids.shape[0], embedding_dim) + trie.build_tree_structure( + semantic_ids, residuals, torch.arange(semantic_ids.shape[0]) + ) + + items_to_query = 5 + batch_size = 1 + q_semantic_ids = semantic_ids[0].repeat(batch_size, 1) + # q_semantic_ids = torch.randint(0, alphabet_size, (batch_size, K), dtype=torch.int64) + q_residuals = torch.randn(batch_size, embedding_dim) + + total_time = 0 + n_exps = 1 + + for i in range(n_exps): + now = time.time() + item_ids = trie.query(q_semantic_ids, q_residuals, items_to_query) + print(semantic_ids[item_ids].shape) + print(q_semantic_ids.shape) + print(semantic_ids[item_ids] == q_semantic_ids) + assert item_ids.shape == (batch_size, items_to_query) + total_time += time.time() - now + + print(f"Time per query: {total_time / n_exps * 1000:.2f} ms") diff --git a/modeling/utils/__init__.py b/modeling/utils/__init__.py index c366aeb6..8b96ac52 100644 --- a/modeling/utils/__init__.py +++ b/modeling/utils/__init__.py @@ -10,8 +10,12 @@ import torch -DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') -# DEVICE = torch.device('cpu') +if torch.cuda.is_available(): + DEVICE = torch.device('cuda:0') +# elif torch.backends.mps.is_available(): +# DEVICE = torch.device("mps:0") +else: + DEVICE = torch.device('cpu') def parse_args(): diff --git a/notebooks/AmazonBeautyDatasetStatistics.ipynb b/notebooks/AmazonBeautyDatasetStatistics.ipynb index bfe19910..e010eecc 100644 --- a/notebooks/AmazonBeautyDatasetStatistics.ipynb +++ b/notebooks/AmazonBeautyDatasetStatistics.ipynb @@ -29,7 +29,7 @@ "outputs": [], "source": [ "path_to_df = '../data/Beauty/ratings_Beauty.csv'\n", - "df = pd.read_csv(path_to_df, names=['user_id', 'item_id', 'rating', 'timestamp'])" + "df = pd.read_csv(path_to_df, names=['raw_user_id', 'raw_item_id', 'rating', 'timestamp'])" ] }, { @@ -59,7 +59,7 @@ "metadata": {}, "outputs": [], "source": [ - "df.user_id.max(), df.user_id.unique().shape" + "df.raw_user_id.max(), df.raw_user_id.unique().shape" ] }, { @@ -69,7 +69,7 @@ "metadata": {}, "outputs": [], "source": [ - "df.user_id = pd.factorize(df.user_id)[0] + 1\n", + "df['user_id'] = pd.factorize(df.raw_user_id)[0] + 1\n", "df.user_id.min(), df.user_id.max(), df.user_id.unique().shape" ] }, @@ -80,7 +80,7 @@ "metadata": {}, "outputs": [], "source": [ - "df.item_id = pd.factorize(df.item_id)[0] + 1\n", + "df['item_id'] = pd.factorize(df.raw_item_id)[0] + 1\n", "df.item_id.min(), df.item_id.max(), df.item_id.unique().shape" ] }, @@ -376,6 +376,29 @@ " ] + 
[str(test_sample['next_interaction']['item_id'])]))\n", + " f.write('\\n')" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "868d5db5", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import pandas as pd\n", + "\n", + "deduped_mapping = df.drop_duplicates(subset=['item_id', 'raw_item_id'])\n", + "\n", + "embs = torch.load('../data/df_with_embs.pt')\n", + "\n", + "merged = pd.merge(deduped_mapping, embs, 'inner', left_on='raw_item_id', right_on='asin')\n", + "merged['item_id'] = merged['item_id'].map(lambda x: item_mapping[x])\n", + " \n", + "assert len(merged) == len(merged.item_id.unique())\n", + "merged = merged.set_index('item_id')\n", + "\n", + "torch.save(merged, '../data/Beauty/data_full.pt')" + ] + } ], "metadata": { diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..f32390d0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,38 @@ +[project] +name = "irec" +version = "0.1.0" +description = "IRec framework" +readme = "README.md" +requires-python = ">=3.13" +dependencies = [ + "faiss-cpu>=1", + "pandas>=2", + "scipy>=1", + "seaborn>=0.13.2", + "tensorboard>=2", + "torch>=2.7", + "transformers>=4.51", + "tqdm>=4", + "jupyter>=1", +] + +[tool.uv.sources] +torch = [ + { index = "pytorch-cu128", marker = "sys_platform != 'darwin'" }, + { index = "pytorch-cpu", marker = "sys_platform == 'darwin'" }, +] + +[[tool.uv.index]] +name = "pytorch-cu128" +url = "https://download.pytorch.org/whl/cu128" +explicit = true + +[[tool.uv.index]] +name = "pytorch-cpu" +url = "https://download.pytorch.org/whl/cpu" +explicit = true + +[dependency-groups] +dev = [ + "ruff>=0.11.4", +] diff --git a/review.md b/review.md new file mode 100644 index 00000000..1a862f85 --- /dev/null +++ b/review.md @@ -0,0 +1,105 @@ +# Review + +## Todos + +- Train dataset size: 16972 (in `sasrec`) +- level embeddings +- fix trie eval +- sos / bos embedding correct train fix +- positions = positions // self._semantic_id_length or reversed? how exactly do we account for codebook_pos & item_pos (same order or inverted) +- how exactly do we find the nearest item on an embedding collision? (unclear which embedding is meant) + +that is: + +1) items (10) +2) semantic_ids (40) -> (1, 2, 3, 4) +3) predicting 11th item (next 4 semantic ids) +4) if single - ok +5) if several? -> dedup # let length be 5 in rqvae, dedup -> 4 + closest by dist +6) if nothing? take all by longest prefix (closest by L^2 / COS / dot) + +encoder -> (b_size x 40 x emb_dim) +target -> (b_size x 4) [(1, 2, 3, 4); (29, 6, 7, 4); ...] +decoder: (bos, 1, 2, 3) -> (1, 2, 3, 4) # causal mask so (bos -> 1), (bos, 1 -> 2), ... + \___ learnable embed +(a minimal teacher-forcing sketch is given after the Fixed list below) + +## Fixed + +- next_item_pred / last_item_pred (which tasks we train and how exactly) # can be both tasks +- predicting an item = predicting 4 semantic ids? # yes +- how to build the training dataset # (map item seq -> semantic id seq) +- we take the ground-truth semantic ids # yes +- is our next item prediction autoregressive? # no (teacher forcing) + +- fix dataset (take last max_seq items) (last_item fixed) +- single sample from single user (honest comparison) +- correct logits indexing with tgt_mask? (upper remark fixes) + +- posterior collapse (everything seemed to collapse into a single codebook index) (fixed eval code) +- do we ignore ratings in the Amazon dataset, i.e. only implicit actions count? # binary dataset (any interaction) +- TODO which base class to use for the seq2seq model? (LastPred?) # use encoder from SequentialTorchModel +- TODO name for the model (tiger) # tmp
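+
+A minimal sketch of the teacher-forced decoding scheme from the Todos above (illustrative only: the tensor names and the learnable `bos_embedding` parameter are my assumptions, not the framework API):
+
+```python
+import torch
+
+batch_size, sem_len, emb_dim, vocab = 2, 4, 64, 256
+
+# ground-truth semantic ids of the next item, e.g. (1, 2, 3, 4) per row
+target = torch.randint(0, vocab, (batch_size, sem_len))  # (b_size x 4)
+
+codebook = torch.nn.Embedding(vocab, emb_dim)
+bos_embedding = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))  # learnable embed
+
+# decoder input (bos, s1, s2, s3): under a causal mask position t sees only
+# bos..s_t and predicts target[:, t], so (bos -> 1), (bos, 1 -> 2), ...
+decoder_input = torch.cat(
+    [bos_embedding.expand(batch_size, -1, -1), codebook(target[:, :-1])], dim=1
+)  # (b_size x 4 x emb_dim)
+
+causal_mask = torch.nn.Transformer.generate_square_subsequent_mask(sem_len)
+```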
+## Remarks + +- using reinit of unused clusters is mandatory! (mark) + +## Links + +- [dataset](https://cseweb.ucsd.edu/~jmcauley/datasets/amazon/links.html) + +## Remarks + +- no biases with the leave-one-out strategy (we cut by a strict temporal threshold) + +## Todo + +### Train full encoder-decoder + +- What do we train on, i.e. which data do we run the backward pass on? +- train model + +### Collisions + +- fix collisions +- last index = `KMeans(last residuals, n=|last codebook|)` - collision +- remainder = last embedding +- auto increment last index (check paper) +- Research last index aggregation + +#### possible collisions example + +- item1: 1 2 3 0 +- item2: 1 2 3 1 +- item3: 4 5 6 0/2 +- item4: 4 5 6 1/3 + +### Retrieve + +- single item -> ok +- too many items -> get embeddings -> score. Softmax(collisions), torch.log_softmax(logits) -> score -> argmax + +### Framework + +- positional emb for item & codebook +- splitting items? + +### positional embeddings example + +- (000 111 222) - item +- (012 012 012) - codebook + +(see the sketch at the end of this document) + +### Fixes in framework + +- user_id & codebook_ids -> repr ??? +- add last 'sequence' prediction, now only last item is supported +- dataloader (semantic ids lens) + +## TODO + +1) Tiger +2) SasRec +3) SasRec frozen (all on single board) +4) Tiger batched inference +5) Tiger honest embedding_dim
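+
+A small sketch of the item/codebook positional indexing from the example above (the integer-division trick is from the Todos notes; variable names here are mine):
+
+```python
+import torch
+
+semantic_id_length, num_items = 3, 3
+positions = torch.arange(num_items * semantic_id_length)  # 0..8, one per token
+
+item_pos = positions // semantic_id_length     # tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])
+codebook_pos = positions % semantic_id_length  # tensor([0, 1, 2, 0, 1, 2, 0, 1, 2])
+
+# each token embedding would then get item_emb[item_pos] + codebook_emb[codebook_pos]
+```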