From 98d5f7773e03a64589eea6514fd92bf64b9efc23 Mon Sep 17 00:00:00 2001 From: Noname Untitled Date: Sun, 9 Nov 2025 01:31:35 +0300 Subject: [PATCH 1/5] MODELING V2 --- .../inference/bert4rec_inference_config.json | 88 -- .../inference/gru4rec_inference_config.json | 87 -- .../inference/light_gcn_inference_config.json | 78 -- configs/inference/pop_inference_config.json | 79 -- .../inference/random_inference_config.json | 78 -- .../inference/sasrec_inference_config.json | 88 -- configs/train/bert4rec_train_cls_config.json | 142 --- configs/train/bert4rec_train_config.json | 167 ---- configs/train/bert4rec_train_grid_config.json | 193 ---- configs/train/cl4srec_train_config.json | 135 --- configs/train/cl4srec_train_grid_config.json | 187 ---- configs/train/duorec_train_config.json | 176 ---- configs/train/duorec_train_grid_config.json | 196 ---- configs/train/gru4rec_train_config.json | 166 ---- configs/train/gru4rec_train_grid_config.json | 184 ---- configs/train/light_gcn_train_config.json | 158 --- .../train/light_gcn_train_grid_config.json | 177 ---- configs/train/mclsr_train_config.json | 229 ----- configs/train/ngcf_train_config.json | 158 --- configs/train/ngcf_train_grid_config.json | 177 ---- configs/train/pure_mf_train_config.json | 100 -- configs/train/s3rec_pretrain_config.json | 93 -- configs/train/s3rec_train_config.json | 150 --- configs/train/sasrec_ce_train_config.json | 146 --- .../train/sasrec_in_batch_train_config.json | 146 --- configs/train/sasrec_real_train_config.json | 217 ----- configs/train/sasrec_train_config.json | 179 ---- configs/train/sasrec_train_grid_config.json | 185 ---- pyproject.toml | 35 +- setup.py | 39 + src/irec/__init__.py | 14 + src/irec/callbacks/__init__.py | 58 +- src/irec/callbacks/base.py | 561 +++++------ src/irec/callbacks/logging.py | 151 +++ src/irec/callbacks/metrics.py | 175 ++++ src/irec/callbacks/model.py | 17 + src/irec/callbacks/profiler.py | 42 + src/irec/callbacks/stats.py | 73 ++ src/irec/callbacks/stopping.py | 74 ++ src/irec/callbacks/timer.py | 64 ++ src/irec/callbacks/train.py | 39 + src/irec/{scheduler => data}/__init__.py | 0 src/irec/data/base.py | 26 + src/irec/data/dataloader.py | 201 ++++ src/irec/data/transforms/__init__.py | 8 + src/irec/data/transforms/base.py | 74 ++ src/irec/dataloader/__init__.py | 8 - src/irec/dataloader/base.py | 43 - src/irec/dataloader/batch_processors.py | 38 - src/irec/dataset/__init__.py | 3 - src/irec/dataset/base.py | 921 ------------------ .../dataset/negative_samplers/__init__.py | 9 - src/irec/dataset/negative_samplers/base.py | 19 - src/irec/dataset/negative_samplers/popular.py | 44 - src/irec/dataset/negative_samplers/random.py | 28 - src/irec/dataset/samplers/__init__.py | 40 - src/irec/dataset/samplers/base.py | 48 - src/irec/dataset/samplers/cl4srec.py | 150 --- src/irec/dataset/samplers/duorec.py | 55 -- .../dataset/samplers/last_item_prediction.py | 50 - .../samplers/masked_item_prediction.py | 99 -- src/irec/dataset/samplers/mclsr.py | 80 -- .../dataset/samplers/next_item_prediction.py | 86 -- src/irec/dataset/samplers/pop.py | 65 -- src/irec/dataset/samplers/s3rec.py | 131 --- src/irec/infer.py | 101 -- src/irec/loss/__init__.py | 1 - src/irec/loss/base.py | 631 ------------ src/irec/metric/__init__.py | 7 - src/irec/metric/base.py | 202 ---- src/irec/models/__init__.py | 40 +- src/irec/models/base.py | 243 +---- src/irec/models/bert4rec.py | 129 --- src/irec/models/bert4rec_cls.py | 114 --- src/irec/models/cl4srec.py | 159 --- src/irec/models/deepfm.py | 0 
src/irec/models/duorec.py | 172 ---- src/irec/models/graph_seq_rec.py | 313 ------ src/irec/models/gru4rec.py | 264 ----- src/irec/models/gtorec.py | 571 ----------- src/irec/models/lightgcn.py | 227 ----- src/irec/models/mclsr.py | 436 --------- src/irec/models/mrgsrec.py | 144 --- src/irec/models/ngcf.py | 244 ----- src/irec/models/old_rqvae.py | 136 +++ src/irec/models/pop.py | 46 - src/irec/models/pure_mf.py | 131 --- src/irec/models/pure_svd.py | 14 - src/irec/models/random.py | 41 - src/irec/models/rqvae.py | 260 +++++ src/irec/models/s3rec.py | 267 ----- src/irec/models/sasrec.py | 221 ----- src/irec/models/sasrec_ce.py | 108 -- src/irec/optimizer/__init__.py | 3 - src/irec/optimizer/base.py | 78 -- src/irec/pretrain.py | 118 --- src/irec/runners/__init__.py | 18 + src/irec/runners/base.py | 155 +++ src/irec/runners/inference.py | 45 + src/irec/runners/train.py | 85 ++ src/irec/scheduler/base.py | 5 - src/irec/train.py | 200 ---- src/irec/train_multiple.py | 185 ---- src/irec/utils.py | 16 + src/irec/utils/__init__.py | 150 --- src/irec/utils/grid_search.py | 60 -- src/irec/utils/registry.py | 79 -- src/irec/utils/tensorboards/__init__.py | 11 - .../utils/tensorboards/tensorboard_writers.py | 37 - 109 files changed, 2027 insertions(+), 12167 deletions(-) delete mode 100644 configs/inference/bert4rec_inference_config.json delete mode 100644 configs/inference/gru4rec_inference_config.json delete mode 100644 configs/inference/light_gcn_inference_config.json delete mode 100644 configs/inference/pop_inference_config.json delete mode 100644 configs/inference/random_inference_config.json delete mode 100644 configs/inference/sasrec_inference_config.json delete mode 100644 configs/train/bert4rec_train_cls_config.json delete mode 100644 configs/train/bert4rec_train_config.json delete mode 100644 configs/train/bert4rec_train_grid_config.json delete mode 100644 configs/train/cl4srec_train_config.json delete mode 100644 configs/train/cl4srec_train_grid_config.json delete mode 100644 configs/train/duorec_train_config.json delete mode 100644 configs/train/duorec_train_grid_config.json delete mode 100644 configs/train/gru4rec_train_config.json delete mode 100644 configs/train/gru4rec_train_grid_config.json delete mode 100644 configs/train/light_gcn_train_config.json delete mode 100644 configs/train/light_gcn_train_grid_config.json delete mode 100644 configs/train/mclsr_train_config.json delete mode 100644 configs/train/ngcf_train_config.json delete mode 100644 configs/train/ngcf_train_grid_config.json delete mode 100644 configs/train/pure_mf_train_config.json delete mode 100644 configs/train/s3rec_pretrain_config.json delete mode 100644 configs/train/s3rec_train_config.json delete mode 100644 configs/train/sasrec_ce_train_config.json delete mode 100644 configs/train/sasrec_in_batch_train_config.json delete mode 100644 configs/train/sasrec_real_train_config.json delete mode 100644 configs/train/sasrec_train_config.json delete mode 100644 configs/train/sasrec_train_grid_config.json create mode 100644 setup.py create mode 100644 src/irec/callbacks/logging.py create mode 100644 src/irec/callbacks/metrics.py create mode 100644 src/irec/callbacks/model.py create mode 100644 src/irec/callbacks/profiler.py create mode 100644 src/irec/callbacks/stats.py create mode 100644 src/irec/callbacks/stopping.py create mode 100644 src/irec/callbacks/timer.py create mode 100644 src/irec/callbacks/train.py rename src/irec/{scheduler => data}/__init__.py (100%) create mode 100644 src/irec/data/base.py create mode 100644 
src/irec/data/dataloader.py create mode 100644 src/irec/data/transforms/__init__.py create mode 100644 src/irec/data/transforms/base.py delete mode 100644 src/irec/dataloader/__init__.py delete mode 100644 src/irec/dataloader/base.py delete mode 100644 src/irec/dataloader/batch_processors.py delete mode 100644 src/irec/dataset/__init__.py delete mode 100644 src/irec/dataset/base.py delete mode 100644 src/irec/dataset/negative_samplers/__init__.py delete mode 100644 src/irec/dataset/negative_samplers/base.py delete mode 100644 src/irec/dataset/negative_samplers/popular.py delete mode 100644 src/irec/dataset/negative_samplers/random.py delete mode 100644 src/irec/dataset/samplers/__init__.py delete mode 100644 src/irec/dataset/samplers/base.py delete mode 100644 src/irec/dataset/samplers/cl4srec.py delete mode 100644 src/irec/dataset/samplers/duorec.py delete mode 100644 src/irec/dataset/samplers/last_item_prediction.py delete mode 100644 src/irec/dataset/samplers/masked_item_prediction.py delete mode 100644 src/irec/dataset/samplers/mclsr.py delete mode 100644 src/irec/dataset/samplers/next_item_prediction.py delete mode 100644 src/irec/dataset/samplers/pop.py delete mode 100644 src/irec/dataset/samplers/s3rec.py delete mode 100644 src/irec/infer.py delete mode 100644 src/irec/loss/__init__.py delete mode 100644 src/irec/loss/base.py delete mode 100644 src/irec/metric/__init__.py delete mode 100644 src/irec/metric/base.py delete mode 100644 src/irec/models/bert4rec.py delete mode 100644 src/irec/models/bert4rec_cls.py delete mode 100644 src/irec/models/cl4srec.py delete mode 100644 src/irec/models/deepfm.py delete mode 100644 src/irec/models/duorec.py delete mode 100644 src/irec/models/graph_seq_rec.py delete mode 100644 src/irec/models/gru4rec.py delete mode 100644 src/irec/models/gtorec.py delete mode 100644 src/irec/models/lightgcn.py delete mode 100644 src/irec/models/mclsr.py delete mode 100644 src/irec/models/mrgsrec.py delete mode 100644 src/irec/models/ngcf.py create mode 100644 src/irec/models/old_rqvae.py delete mode 100644 src/irec/models/pop.py delete mode 100644 src/irec/models/pure_mf.py delete mode 100644 src/irec/models/pure_svd.py delete mode 100644 src/irec/models/random.py create mode 100644 src/irec/models/rqvae.py delete mode 100644 src/irec/models/s3rec.py delete mode 100644 src/irec/models/sasrec.py delete mode 100644 src/irec/models/sasrec_ce.py delete mode 100644 src/irec/optimizer/__init__.py delete mode 100644 src/irec/optimizer/base.py delete mode 100644 src/irec/pretrain.py create mode 100644 src/irec/runners/__init__.py create mode 100644 src/irec/runners/base.py create mode 100644 src/irec/runners/inference.py create mode 100644 src/irec/runners/train.py delete mode 100644 src/irec/scheduler/base.py delete mode 100644 src/irec/train.py delete mode 100644 src/irec/train_multiple.py create mode 100644 src/irec/utils.py delete mode 100644 src/irec/utils/__init__.py delete mode 100644 src/irec/utils/grid_search.py delete mode 100644 src/irec/utils/registry.py delete mode 100644 src/irec/utils/tensorboards/__init__.py delete mode 100644 src/irec/utils/tensorboards/tensorboard_writers.py diff --git a/configs/inference/bert4rec_inference_config.json b/configs/inference/bert4rec_inference_config.json deleted file mode 100644 index 67b6c3f7..00000000 --- a/configs/inference/bert4rec_inference_config.json +++ /dev/null @@ -1,88 +0,0 @@ -{ - "pred_prefix": "logits", - "label_prefix": "labels", - "experiment_name": "bert4rec_beauty_grid_0-5_0-1__", - "dataset": { - 
"type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "mask_prob": 0.5, - "type": "masked_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "bert4rec", - "sequence_prefix": "item", - "labels_prefix": "labels", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.1, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } -} diff --git a/configs/inference/gru4rec_inference_config.json b/configs/inference/gru4rec_inference_config.json deleted file mode 100644 index d315de62..00000000 --- a/configs/inference/gru4rec_inference_config.json +++ /dev/null @@ -1,87 +0,0 @@ -{ - "pred_prefix": "logits", - "label_prefix": "labels", - "experiment_name": "gru4rec_beauty_grid__0-2__", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "gru4rec", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.2, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } -} diff --git a/configs/inference/light_gcn_inference_config.json b/configs/inference/light_gcn_inference_config.json deleted file mode 100644 index de4a925c..00000000 --- a/configs/inference/light_gcn_inference_config.json +++ /dev/null @@ -1,78 +0,0 @@ -{ - "pred_prefix": "predictions", - "label_prefix": "labels", - "experiment_name": "light_gcn_inference_beauty", - "dataset": { - "type": "graph", - "graph_dir_path": "./data/Beauty", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": 
"Beauty", - "max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "light_gcn", - "user_prefix": "user", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_layers": 3, - "keep_prob": 1.0, - "dropout": 0.0, - "initializer_range": 0.02 - }, - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } -} diff --git a/configs/inference/pop_inference_config.json b/configs/inference/pop_inference_config.json deleted file mode 100644 index c3dad45b..00000000 --- a/configs/inference/pop_inference_config.json +++ /dev/null @@ -1,79 +0,0 @@ -{ - "pred_prefix": "predictions", - "label_prefix": "labels", - "experiment_name": "pop_inference_beauty", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "pop", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "pop", - "label_prefix": "labels", - "candidate_prefix": "candidates", - "counts_prefix": "candidates_counts" - }, - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } -} diff --git a/configs/inference/random_inference_config.json b/configs/inference/random_inference_config.json deleted file mode 100644 index bbec82fb..00000000 --- a/configs/inference/random_inference_config.json +++ /dev/null @@ -1,78 +0,0 @@ -{ - "pred_prefix": "predictions", - "label_prefix": "labels", - "experiment_name": "random_inference_beauty", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "random", - "label_prefix": "labels", - "candidate_prefix": "candidates" - }, - "metrics": { - 
"ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } -} diff --git a/configs/inference/sasrec_inference_config.json b/configs/inference/sasrec_inference_config.json deleted file mode 100644 index c91e21bf..00000000 --- a/configs/inference/sasrec_inference_config.json +++ /dev/null @@ -1,88 +0,0 @@ -{ - "pred_prefix": "logits", - "label_prefix": "labels", - "experiment_name": "sasrec_beauty_grid__0-5__", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.5, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } -} diff --git a/configs/train/bert4rec_train_cls_config.json b/configs/train/bert4rec_train_cls_config.json deleted file mode 100644 index 56848e37..00000000 --- a/configs/train/bert4rec_train_cls_config.json +++ /dev/null @@ -1,142 +0,0 @@ -{ - "experiment_name": "bert4rec_beauty", - "best_metric": "eval/ndcg@20", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "last_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "bert4rec_cls", - "user_prefix": "user", - "sequence_prefix": "item", - "labels_prefix": "labels", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.2, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 1e-4 - }, - "clip_grad_threshold": 5.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": 
"batch_logsoftmax", - "predictions_prefix": "predictions", - "candidates_prefix": "candidates" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/bert4rec_train_config.json b/configs/train/bert4rec_train_config.json deleted file mode 100644 index 9818ec9d..00000000 --- a/configs/train/bert4rec_train_config.json +++ /dev/null @@ -1,167 +0,0 @@ -{ - "experiment_name": "bert4rec_beauty", - "best_metric": "validation/ndcg@20", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "mask_prob": 0.7, - "type": "masked_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "bert4rec", - "sequence_prefix": "item", - "labels_prefix": "labels", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.2, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "ce", - "predictions_prefix": "logits", - "labels_prefix": "labels", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": 
"ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/bert4rec_train_grid_config.json b/configs/train/bert4rec_train_grid_config.json deleted file mode 100644 index d1602cb0..00000000 --- a/configs/train/bert4rec_train_grid_config.json +++ /dev/null @@ -1,193 +0,0 @@ -{ - "start_from": 0, - "experiment_name": "bert4rec_beauty_grid", - "best_metric": "validation/ndcg@20", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "masked_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataset_params": { - "samplers": { - "mask_prob": [ - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8 - ] - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "bert4rec", - "sequence_prefix": "item", - "labels_prefix": "labels", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "model_params": { - "dropout": [ - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8] - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "optimizer_params": { - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "ce", - "predictions_prefix": "logits", - "labels_prefix": "labels", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "loss_params": { - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/cl4srec_train_config.json b/configs/train/cl4srec_train_config.json deleted file mode 100644 index 
0a4656b2..00000000 --- a/configs/train/cl4srec_train_config.json +++ /dev/null @@ -1,135 +0,0 @@ -{ - "experiment_name": "cl4srec_clothing_test_labels", - "best_metric": "validation/ndcg@20", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "Clothing", - "max_sequence_length": 50, - "samplers": { - "type": "cl4srec", - "negative_sampler_type": "random", - "item_crop_portion": 0.5, - "item_mask_portion": 0.5, - "item_reorder_portion": 0.5 - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "cl4srec", - "sequence_prefix": "item", - "fst_augmented_sequence_prefix": "fst_augmented_item", - "snd_augmented_sequence_prefix": "snd_augmented_item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "labels_prefix": "labels", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.6, - "activation": "relu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - } - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "ce", - "predictions_prefix": "logits", - "labels_prefix": "labels", - "output_prefix": "cl4srec_beauty_downstream_loss" - }, - { - "type": "fps", - "fst_embeddings_prefix": "fst_aug_sequence_representation", - "snd_embeddings_prefix": "snd_aug_sequence_representation", - "tau": 1.0, - "weight": 0.2, - "normalize_embeddings": false, - "output_prefix": "cl4srec_beauty_contrastive_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "cl4srec_beauty_downstream_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "cl4srec_beauty_contrastive_loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/cl4srec_train_grid_config.json b/configs/train/cl4srec_train_grid_config.json deleted file mode 100644 index a61010a7..00000000 --- a/configs/train/cl4srec_train_grid_config.json +++ /dev/null @@ -1,187 +0,0 @@ -{ - "start_from": 0, - "num_exps": 100, - "experiment_name": "cl4srec_ml_grid", - "best_metric": "eval/ndcg@20", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "MovieLens1M", - "max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "cl4srec", - "negative_sampler_type": "random" - } - }, - "dataset_params": { - "samplers": { - "item_crop_portion": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], - "item_mask_portion": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], - "item_reorder_portion": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - 
}, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "cl4srec", - "sequence_prefix": "item", - "fst_augmented_sequence_prefix": "fst_augmented_item", - "snd_augmented_sequence_prefix": "snd_augmented_item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "num_layers": 2, - "num_heads": 2, - "embedding_dim": 64, - "dim_feedforward": 256, - "activation": "relu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "model_params": { - "dropout": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - } - }, - "optimizer_params": { - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "cl4srec", - "current_representation": "sequence_representation", - "all_items_representation": "all_items_representation", - "tau": 1.0, - "weight": 1.0, - "output_prefix": "cl4srec_ml_downstream_loss" - }, - { - "type": "fps", - "fst_embeddings_prefix": "fst_aug_sequence_representation", - "snd_embeddings_prefix": "snd_aug_sequence_representation", - "tau": 1.0, - "output_prefix": "cl4srec_ml_contrastive_loss" - } - ], - "output_prefix": "cl4srec_ml_loss" - }, - "loss_params": { - "losses": [ - {}, - { - "weight": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], - "normalize_embeddings": [true, false], - "add_negatives": [false, true] - } - ] - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "cl4srec_ml_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "cl4srec_ml_downstream_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "cl4srec_ml_contrastive_loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/duorec_train_config.json b/configs/train/duorec_train_config.json deleted file mode 100644 index 9d381820..00000000 --- a/configs/train/duorec_train_config.json +++ /dev/null @@ -1,176 +0,0 @@ -{ - "experiment_name": "duorec_beauty_test", - "best_metric": "eval/ndcg@20", - "dataset": { - "type": "duorec", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "duorec", - "negative_sampler_type": "random" - } - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - 
"drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "duorec", - "sequence_prefix": "item", - "augmented_sequence_prefix": "semantic_similar_item", - "labels_prefix": "labels", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.4, - "activation": "relu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - } - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "ce", - "predictions_prefix": "logits", - "labels_prefix": "labels", - "output_prefix": "duorec_downstream_loss" - }, - { - "type": "duorec_ssl", - "original_embedding_prefix": "sequence_representation", - "dropout_embedding_prefix": "similar_sequence_representation", - "similar_embedding_prefix": "augmented_sequence_representation", - "weight": 0.4, - "normalize_embeddings": false, - "tau": 1.0, - "output_prefix": "duorec_reg_loss" - } - ], - "output_prefix": "duorec_loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "duorec_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "duorec_downstream_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "duorec_reg_loss_dropout" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "duorec_reg_loss_ssl" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "duorec_reg_loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/duorec_train_grid_config.json b/configs/train/duorec_train_grid_config.json deleted file mode 100644 index 4a753976..00000000 --- a/configs/train/duorec_train_grid_config.json +++ /dev/null @@ -1,196 +0,0 @@ -{ - "start_from": 0, - "num_exps": 100, - "experiment_name": "duorec_ml_grid", - "best_metric": "eval/ndcg@20", - "dataset": { - "type": "duorec", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "MovieLens1M", - "max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "duorec", - "negative_sampler_type": "random" - } - } - }, - "dataset_params": { - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "duorec", - "sequence_prefix": "item", - "augmented_sequence_prefix": "semantic_similar_item", - "labels_prefix": "labels", - "candidate_prefix": "candidates", - "embedding_dim": 64, - 
"num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "activation": "relu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "model_params": { - "dropout": [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8] - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - } - }, - "optimizer_params": { - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "ce", - "predictions_prefix": "logits", - "labels_prefix": "labels", - "output_prefix": "duorec_ml_downstream_loss" - }, - { - "type": "duorec_ssl", - "fst_embeddings_prefix": "sequence_representation", - "snd_embeddings_prefix": "similar_sequence_representation", - "output_prefix": "duorec_ml_ssl_droupout" - }, - { - "type": "duorec_ssl", - "fst_embeddings_prefix": "similar_sequence_representation", - "snd_embeddings_prefix": "augmented_sequence_representation", - "output_prefix": "duorec_ml_ssl_augmented" - } - ], - "output_prefix": "duorec_ml_loss" - }, - "loss_params": { - "losses": [ - {}, - { - "tau": [0.1, 1.0, 10.0], - "weight": [0.1, 0.2, 0.3, 0.4, 0.5, 1.0], - "normalize_embeddings": [false, true] - }, - { - "tau": [0.1, 1.0, 10.0], - "weight": [0.1, 0.2, 0.3, 0.4, 0.5, 1.0], - "normalize_embeddings": [false, true] - } - ] - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "duorec_ml_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "duorec_ml_downstream_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "duorec_ml_ssl_droupout" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "duorec_ml_ssl_augmented" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/gru4rec_train_config.json b/configs/train/gru4rec_train_config.json deleted file mode 100644 index 5b19a886..00000000 --- a/configs/train/gru4rec_train_config.json +++ /dev/null @@ -1,166 +0,0 @@ -{ - "experiment_name": "gru4rec_beauty", - "best_metric": "validation/ndcg@20", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "gru4rec", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_layers": 2, - 
"dim_feedforward": 256, - "dropout": 0.3, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "bpr", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/gru4rec_train_grid_config.json b/configs/train/gru4rec_train_grid_config.json deleted file mode 100644 index 2e3db004..00000000 --- a/configs/train/gru4rec_train_grid_config.json +++ /dev/null @@ -1,184 +0,0 @@ -{ - "start_from": 0, - "experiment_name": "gru4rec_beauty_grid", - "best_metric": "validation/ndcg@20", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataset_params": { - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "gru4rec", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_layers": 2, - "dim_feedforward": 256, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "model_params": { - "dropout": [ - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 0.9] - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "optimizer_params": { - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "bpr", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "loss_params": { - }, - "callback": { - "type": "composite", - "callbacks": [ 
- { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/light_gcn_train_config.json b/configs/train/light_gcn_train_config.json deleted file mode 100644 index d374e920..00000000 --- a/configs/train/light_gcn_train_config.json +++ /dev/null @@ -1,158 +0,0 @@ -{ - "experiment_name": "light_gcn_beauty_test", - "dataset": { - "type": "graph", - "graph_dir_path": "./data/Beauty", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "light_gcn", - "user_prefix": "user", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_layers": 3, - "dropout": 0.2, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - } - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "bpr", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "lightgcn_downstream_loss" - }, - { - "type": "regularization", - "prefix": ["item_embeddings"], - "weight": 1e-2, - "output_prefix": "lightgcn_regularization_loss" - } - ], - "output_prefix": "lightgcn_loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "lightgcn_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "lightgcn_downstream_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "lightgcn_regularization_loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - 
"k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/light_gcn_train_grid_config.json b/configs/train/light_gcn_train_grid_config.json deleted file mode 100644 index 5d661f33..00000000 --- a/configs/train/light_gcn_train_grid_config.json +++ /dev/null @@ -1,177 +0,0 @@ -{ - "start_from": 0, - "experiment_name": "light_gcn_beauty_grid", - "best_metric": "validation/ndcg@20", - "num_exps": 60, - "dataset": { - "type": "graph", - "graph_dir_path": "./data/Beauty", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - } - }, - "dataset_params": { - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "light_gcn", - "user_prefix": "user", - "positive_prefix": "positive", - "embedding_dim": 64, - "initializer_range": 0.02 - }, - "model_params": { - "num_layers": [ - 2, - 3], - "dropout": [ - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8] - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - } - }, - "optimizer_params": { - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "bpr", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "lightgcn_downstream_loss" - }, - { - "type": "regularization", - "prefix": [ - "item_embeddings"], - "output_prefix": "lightgcn_regularization_loss" - } - ], - "output_prefix": "loss" - }, - "loss_params": { - "losses": [ - { - }, - { - "weight": [ - 1e-5, - 1e-4, - 1e-3, - 1e-2, - 1e-1] - } - ] - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/mclsr_train_config.json b/configs/train/mclsr_train_config.json deleted file mode 100644 index 6c0689bd..00000000 --- a/configs/train/mclsr_train_config.json +++ /dev/null @@ -1,229 +0,0 @@ -{ - 
"experiment_name": "mclsr_Real_Toys_hit", - "use_wandb": true, - "best_metric": "validation/ndcg@20", - "dataset": { - "type": "graph", - "use_user_graph": true, - "use_item_graph": true, - "neighborhood_size": 50, - "graph_dir_path": "./data/Real_Toys", - "dataset": { - "type": "mclsr", - "path_to_data_dir": "./data", - "name": "Real_Toys", - "max_sequence_length": 20, - "samplers": { - "num_negatives_val": 1280, - "num_negatives_train": 1280, - "type": "mclsr", - "negative_sampler_type": "random" - } - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 128, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 128, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "mclsr", - "sequence_prefix": "item", - "user_prefix": "user", - "labels_prefix": "labels", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_graph_layers": 2, - "dropout": 0.3, - "layer_norm_eps": 1e-9, - "graph_dropout": 0.3, - "initializer_range": 0.02, - "alpha": 0.5 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - } - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sampled_softmax", - "queries_prefix": "combined_representation", - "positive_prefix": "label_representation", - "negative_prefix": "negative_representation", - "output_prefix": "downstream_loss", - "weight": 1.0 - }, - { - "type": "fps", - "fst_embeddings_prefix": "sequential_representation", - "snd_embeddings_prefix": "graph_representation", - "output_prefix": "contrastive_interest_loss", - "weight": 1.0, - "temperature": 0.5 - }, - { - "type": "fps", - "fst_embeddings_prefix": "user_graph_user_embeddings", - "snd_embeddings_prefix": "common_graph_user_embeddings", - "output_prefix": "contrastive_user_feature_loss", - "weight": 0.05, - "temperature": 0.5 - }, - { - "type": "fps", - "fst_embeddings_prefix": "item_graph_item_embeddings", - "snd_embeddings_prefix": "common_graph_item_embeddings", - "output_prefix": "contrastive_item_feature_loss", - "weight": 0.05, - "temperature": 0.5 - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "downstream_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "contrastive_interest_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "contrastive_user_feature_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "contrastive_item_feature_loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "predictions", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "mclsr-ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "mclsr-ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "mclsr-ndcg", - "k": 20 - }, - "ndcg@50": { - "type": "mclsr-ndcg", - "k": 50 - }, - "recall@5": { - "type": "mclsr-recall", - "k": 5 - }, - "recall@10": { - "type": "mclsr-recall", - "k": 10 - }, - "recall@20": { - "type": "mclsr-recall", - "k": 20 - }, - "recall@50": { - "type": "mclsr-recall", - "k": 50 - }, - "hit@20": { - "type": "mclsr-hit", - "k": 20 - }, - "hit@50": { - "type": "mclsr-hit", - "k": 50 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "predictions", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": 
"mclsr-ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "mclsr-ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "mclsr-ndcg", - "k": 20 - }, - "ndcg@50": { - "type": "mclsr-ndcg", - "k": 50 - }, - "recall@5": { - "type": "mclsr-recall", - "k": 5 - }, - "recall@10": { - "type": "mclsr-recall", - "k": 10 - }, - "recall@20": { - "type": "mclsr-recall", - "k": 20 - }, - "recall@50": { - "type": "mclsr-recall", - "k": 50 - }, - "hit@20": { - "type": "mclsr-hit", - "k": 20 - }, - "hit@50": { - "type": "mclsr-hit", - "k": 50 - } - } - } - ] - } -} diff --git a/configs/train/ngcf_train_config.json b/configs/train/ngcf_train_config.json deleted file mode 100644 index 5aa26d5b..00000000 --- a/configs/train/ngcf_train_config.json +++ /dev/null @@ -1,158 +0,0 @@ -{ - "experiment_name": "ngcf_beauty_test", - "dataset": { - "type": "graph", - "graph_dir_path": "./data/Beauty", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "ngcf", - "user_prefix": "user", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_layers": 3, - "dropout": 0.2, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - } - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "bpr", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "lightgcn_downstream_loss" - }, - { - "type": "regularization", - "prefix": ["item_embeddings"], - "weight": 1e-2, - "output_prefix": "lightgcn_regularization_loss" - } - ], - "output_prefix": "lightgcn_loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "lightgcn_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "lightgcn_downstream_loss" - }, - { - "type": "metric", - "on_step": 1, - "loss_prefix": "lightgcn_regularization_loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/ngcf_train_grid_config.json b/configs/train/ngcf_train_grid_config.json deleted file mode 100644 index 13029016..00000000 --- a/configs/train/ngcf_train_grid_config.json +++ /dev/null @@ -1,177 +0,0 @@ -{ 
- "start_from": 0, - "experiment_name": "ngcf_beauty_grid", - "best_metric": "validation/ndcg@20", - "num_exps": 60, - "dataset": { - "type": "graph", - "graph_dir_path": "./data/Beauty", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - } - }, - "dataset_params": { - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "ngcf", - "user_prefix": "user", - "positive_prefix": "positive", - "embedding_dim": 64, - "initializer_range": 0.02 - }, - "model_params": { - "dropout": [ - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8], - "num_layers": [ - 2, - 3] - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - } - }, - "optimizer_params": { - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "bpr", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "ngcf_downstream_loss" - }, - { - "type": "regularization", - "prefix": [ - "item_embeddings"], - "output_prefix": "ngcf_regularization_loss" - } - ], - "output_prefix": "loss" - }, - "loss_params": { - "losses": [ - { - }, - { - "weight": [ - 1e-5, - 1e-4, - 1e-3, - 1e-2, - 1e-1] - } - ] - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/pure_mf_train_config.json b/configs/train/pure_mf_train_config.json deleted file mode 100644 index f1623d11..00000000 --- a/configs/train/pure_mf_train_config.json +++ /dev/null @@ -1,100 +0,0 @@ -{ - "experiment_name": "matrix_factorization_beauty_test", - "train_epochs_num": 10, - "dataset": { - "type": "sequence", - "name": "Beauty", - "path_to_data_dir": "./data", - "max_sequence_length": 200, - "samplers": { - "num_negatives_train": 100, - "num_negatives_val": 100, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "pure_mf", - "user_prefix": "user", - "positive_prefix": "positive", 
- "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 256, - "initializer_range": 0.02 - }, - "loss": { - "type": "bpr", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "loss" - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - } - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 32, - "pred_prefix": "scores", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/s3rec_pretrain_config.json b/configs/train/s3rec_pretrain_config.json deleted file mode 100644 index 0c83936c..00000000 --- a/configs/train/s3rec_pretrain_config.json +++ /dev/null @@ -1,93 +0,0 @@ -{ - "experiment_name": "pretrain_s3rec_beauty_test_new", - "train_epochs_num": 100, - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "s3rec_pretrain", - "negative_sampler_type": "random", - "mask_prob": 0.2 - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "s3rec", - "is_training": false, - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "sequence_segment_prefix": "item_segment", - "positive_segment_prefix": "positive_segment", - "negative_segment_prefix": "negative_segment", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.6, - "activation": "relu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - } - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "s3rec_pretrain", - "positive_prefix": "positive_representation", - "negative_prefix": "negative_representation", - "representation_prefix": "current_representation", - "output_prefix": "mip_loss", - "weight": 1.0 - }, - { - "type": "s3rec_pretrain", - "positive_prefix": "positive_segment_representation", - "negative_prefix": "negative_segment_representation", - "representation_prefix": "current_segment_representation", - "output_prefix": "sp_loss", - "weight": 0.0 - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - } - ] - } -} diff --git a/configs/train/s3rec_train_config.json b/configs/train/s3rec_train_config.json deleted file mode 100644 index e9bbdb31..00000000 --- a/configs/train/s3rec_train_config.json +++ /dev/null @@ -1,150 +0,0 @@ -{ - "experiment_name": "s3rec_beauty_train", - "checkpoint": "pretrain_s3rec_beauty_test_new_final_state", - "best_metric": "eval/ndcg@20", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "Beauty", - 
"max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "s3rec", - "is_training": true, - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "sequence_segment_prefix": "item_segment", - "positive_segment_prefix": "positive_segment", - "negative_segment_prefix": "negative_segment", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.6, - "activation": "relu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sasrec", - "positive_prefix": "positive_embeddings", - "negative_prefix": "negative_embeddings", - "representation_prefix": "current_embeddings", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/sasrec_ce_train_config.json b/configs/train/sasrec_ce_train_config.json deleted file mode 100644 index ce17ae3f..00000000 --- a/configs/train/sasrec_ce_train_config.json +++ /dev/null @@ -1,146 +0,0 @@ -{ - "experiment_name": "sasrec_test", - "best_metric": "eval/ndcg@20", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.3, - "activation": "gelu", - "use_ce": true, - "layer_norm_eps": 
1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sasrec", - "positive_prefix": "positive_embeddings", - "negative_prefix": "negative_embeddings", - "representation_prefix": "current_embeddings", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/sasrec_in_batch_train_config.json b/configs/train/sasrec_in_batch_train_config.json deleted file mode 100644 index 9950f313..00000000 --- a/configs/train/sasrec_in_batch_train_config.json +++ /dev/null @@ -1,146 +0,0 @@ -{ - "experiment_name": "sasrec_in_batch_test", - "best_metric": "eval/ndcg@20", - "dataset": { - "type": "scientific", - "path_to_data_dir": "data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "num_negatives_val": 100, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec_in_batch", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.3, - "activation": "gelu", - "use_ce": true, - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sampled_softmax", - "queries_prefix": "query_embeddings", - "positive_prefix": "positive_embeddings", - "negative_prefix": "negative_embeddings", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": 
"recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - } - } - } - ] - } -} diff --git a/configs/train/sasrec_real_train_config.json b/configs/train/sasrec_real_train_config.json deleted file mode 100644 index b04a5a08..00000000 --- a/configs/train/sasrec_real_train_config.json +++ /dev/null @@ -1,217 +0,0 @@ -{ - "experiment_name": "sasrec_real_beauty", - "best_metric": "validation/ndcg@100", - "dataset": { - "type": "scientific", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 100, - "samplers": { - "num_negatives_train": 1, - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 2048, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec_real", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.3, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sasrec_real", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "ndcg@50": { - "type": "ndcg", - "k": 50 - }, - "ndcg@100": { - "type": "ndcg", - "k": 100 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "recall@50": { - "type": "recall", - "k": 50 - }, - "recall@100": { - "type": "recall", - "k": 100 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - }, - "coverage@50": { - "type": "coverage", - "k": 50 - }, - "coverage@100": { - "type": "coverage", - "k": 100 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "ndcg@50": { - "type": "ndcg", - "k": 50 - }, - "ndcg@100": { - "type": "ndcg", - "k": 100 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - 
"type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "recall@50": { - "type": "recall", - - "k": 50 - }, - "recall@100": { - "type": "recall", - "k": 100 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - }, - "coverage@50": { - "type": "coverage", - "k": 50 - }, - "coverage@100": { - "type": "coverage", - "k": 100 - } - } - } - ] - } -} diff --git a/configs/train/sasrec_train_config.json b/configs/train/sasrec_train_config.json deleted file mode 100644 index ac35763d..00000000 --- a/configs/train/sasrec_train_config.json +++ /dev/null @@ -1,179 +0,0 @@ -{ - "experiment_name": "sasrec_Real_Toys_hit", - "use_wandb": true, - "best_metric": "validation/ndcg@20", - "dataset": { - "type": "sasrec_comparison", - "path_to_data_dir": "./data", - "name": "Real_Toys", - "max_sequence_length": 20, - "train_sampler": { - "type": "next_item_prediction", - "negative_sampler_type": "random", - "num_negatives_train": 0 - }, - "eval_sampler": { - "type": "mclsr" - } - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 128, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 128, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec", - "sequence_prefix": "item", - "positive_prefix": "positive", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "dropout": 0.3, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 1.0 - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sasrec", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "predictions", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "mclsr-ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "mclsr-ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "mclsr-ndcg", - "k": 20 - }, - "ndcg@50": { - "type": "mclsr-ndcg", - "k": 50 - }, - "recall@5": { - "type": "mclsr-recall", - "k": 5 - }, - "recall@10": { - "type": "mclsr-recall", - "k": 10 - }, - "recall@20": { - "type": "mclsr-recall", - "k": 20 - }, - "recall@50": { - "type": "mclsr-recall", - "k": 50 - }, - "hit@20": { - "type": "mclsr-hit", - "k": 20 - }, - "hit@50": { - "type": "mclsr-hit", - "k": 50 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "predictions", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "mclsr-ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "mclsr-ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "mclsr-ndcg", - "k": 20 - }, - "ndcg@50": { - "type": "mclsr-ndcg", - "k": 50 - }, - "recall@5": { - "type": "mclsr-recall", - "k": 5 - }, - "recall@10": { - "type": "mclsr-recall", - "k": 10 - }, - "recall@20": { - "type": "mclsr-recall", - "k": 20 - }, - "recall@50": { - "type": "mclsr-recall", - "k": 50 - }, - "hit@20": { - "type": "mclsr-hit", - "k": 20 - }, - "hit@50": { - "type": 
"mclsr-hit", - "k": 50 - } - } - } - ] - } -} \ No newline at end of file diff --git a/configs/train/sasrec_train_grid_config.json b/configs/train/sasrec_train_grid_config.json deleted file mode 100644 index acf86b61..00000000 --- a/configs/train/sasrec_train_grid_config.json +++ /dev/null @@ -1,185 +0,0 @@ -{ - "start_from": 0, - "experiment_name": "sasrec_beauty_grid", - "best_metric": "validation/ndcg@20", - "dataset": { - "type": "sequence", - "path_to_data_dir": "./data", - "name": "Beauty", - "max_sequence_length": 50, - "samplers": { - "type": "next_item_prediction", - "negative_sampler_type": "random" - } - }, - "dataset_params": { - }, - "dataloader": { - "train": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": true, - "shuffle": true - }, - "validation": { - "type": "torch", - "batch_size": 256, - "batch_processor": { - "type": "basic" - }, - "drop_last": false, - "shuffle": false - } - }, - "model": { - "type": "sasrec", - "sequence_prefix": "item", - "positive_prefix": "positive", - "negative_prefix": "negative", - "candidate_prefix": "candidates", - "embedding_dim": 64, - "num_heads": 2, - "num_layers": 2, - "dim_feedforward": 256, - "activation": "gelu", - "layer_norm_eps": 1e-9, - "initializer_range": 0.02 - }, - "model_params": { - "dropout": [ - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 0.9] - }, - "optimizer": { - "type": "basic", - "optimizer": { - "type": "adam", - "lr": 0.001 - }, - "clip_grad_threshold": 5.0 - }, - "optimizer_params": { - }, - "loss": { - "type": "composite", - "losses": [ - { - "type": "sasrec", - "positive_prefix": "positive_scores", - "negative_prefix": "negative_scores", - "output_prefix": "downstream_loss" - } - ], - "output_prefix": "loss" - }, - "loss_params": { - }, - "callback": { - "type": "composite", - "callbacks": [ - { - "type": "metric", - "on_step": 1, - "loss_prefix": "loss" - }, - { - "type": "validation", - "on_step": 64, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - }, - { - "type": "eval", - "on_step": 256, - "pred_prefix": "logits", - "labels_prefix": "labels", - "metrics": { - "ndcg@5": { - "type": "ndcg", - "k": 5 - }, - "ndcg@10": { - "type": "ndcg", - "k": 10 - }, - "ndcg@20": { - "type": "ndcg", - "k": 20 - }, - "recall@5": { - "type": "recall", - "k": 5 - }, - "recall@10": { - "type": "recall", - "k": 10 - }, - "recall@20": { - "type": "recall", - "k": 20 - }, - "coverage@5": { - "type": "coverage", - "k": 5 - }, - "coverage@10": { - "type": "coverage", - "k": 10 - }, - "coverage@20": { - "type": "coverage", - "k": 20 - } - } - } - ] - } -} diff --git a/pyproject.toml b/pyproject.toml index 5f4075fd..c06fd718 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,35 +1,12 @@ [build-system] -requires = ["setuptools>=64.0"] +requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "irec" -description = "Graph Sequentional recomendation system framework" -version = "1.0" -requires-python = ">=3.12" - -dependencies = [ - "numpy==1.26.*", - 
"torch==2.4.*", - "tqdm==4.66.*", - "scipy==1.14.*", - "pandas==2.2.*", - "polars==1.27.*", - "matplotlib==3.9.*", - "tensorboard==2.19.*", - "wandb==0.19.*", - "jupyterlab==4.4.*", - "ipykernel==6.29.*", - "notebook==7.4.*", -] - -[project.optional-dependencies] -dev = [ - "mypy==1.11.*", - "ruff==0.11.*", - "types-tqdm==4.67.*", - "scipy-stubs==1.15.*", +version = "1.0.0" +description = "Framework for R&D RecSys projects" +authors = [ + {name = "Vladimir Baikalov", email = "nonameuntitled159@gmail.com"} ] - -[project.scripts] -train="irec.train:main" +requires-python = ">=3.12" \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..17668574 --- /dev/null +++ b/setup.py @@ -0,0 +1,39 @@ +from setuptools import setup, find_packages +from pathlib import Path + +install_requires = [ + 'numpy', + 'lmdb', + 'loguru', + 'murmurhash', + 'scikit-learn', + 'tensorboard', + 'tqdm', + 'onnx', + 'brotli', + 'lz4', + 'mpi4py', + 'pynvml', + 'zstandard', +] + +with open(Path(__file__).parent / 'README.md', encoding='utf-8') as f: + long_description = f.read() + + +setup( + name='irec', + version='1.0.0', + package_dir={'': 'src'}, + description='Framework for R&D RecSys projects', + author='Vladimir Baikalov', + author_email='nonameuntitled159@gmail.com', + url='https://github.com/CTLab-ITMO/IRec', + packages=find_packages( + where='src', + exclude=['tests', 'tests.*', '*.tests', '*.tests.*'] + ), + install_requires=install_requires, + python_requires='>=3.12', + +) \ No newline at end of file diff --git a/src/irec/__init__.py b/src/irec/__init__.py index e69de29b..d3de3311 100644 --- a/src/irec/__init__.py +++ b/src/irec/__init__.py @@ -0,0 +1,14 @@ +from loguru import logger + +try: + import torch +except ImportError: + logger.error('torch is not available, please install torch to use irec framework') +else: + + __all__ = [ + 'callbacks', + 'data', + 'runners', + 'utils' + ] diff --git a/src/irec/callbacks/__init__.py b/src/irec/callbacks/__init__.py index 45b5b205..1111adc7 100644 --- a/src/irec/callbacks/__init__.py +++ b/src/irec/callbacks/__init__.py @@ -1,15 +1,47 @@ -from .base import ( - BaseCallback, - CompositeCallback, - EvalCallback, - InferenceCallback, - ValidationCallback, -) +from .base import * +from .train import * + + +from irec.callbacks.base import Callback, BatchCallback, Composite +from irec.callbacks.logging import Logger, LoggingCallback, TensorboardLogger +from irec.callbacks.metrics import BatchMetrics, LambdaMetrics, MetricAccumulator, Accumulator, MeanAccumulator, Validation +from irec.callbacks.model import LoadModel +from irec.callbacks.profiler import Profiler +from irec.callbacks.stats import Thermometer +from irec.callbacks.stopping import EarlyStopping, StopAfterNumSteps +from irec.callbacks.timer import CpuTimer, MeasureStepTime, MeasureLoadingTime, MeasureTotalStepTime +from irec.callbacks.train import TrainingCallback, ClipGradient + __all__ = [ - 'BaseCallback', - 'CompositeCallback', - 'EvalCallback', - 'InferenceCallback', - 'ValidationCallback', -] + 'Callback', + 'BatchCallback', + 'TrainingCallback', + 'Composite', + + 'Logger', + 'LoggingCallback', + 'TensorboardLogger', + + 'BatchMetrics', + 'LambdaMetrics', + 'MetricAccumulator', + 'Validation', + + 'Accumulator', + 'MeanAccumulator', + + 'LoadModel', + + 'Profiler', + 'Thermometer', + 'EarlyStopping', + 'StopAfterNumSteps', + + 'CpuTimer', + 'MeasureStepTime', + 'MeasureLoadingTime', + 'MeasureTotalStepTime', + + 'ClipGradient', +] \ No newline at end of file 
diff --git a/src/irec/callbacks/base.py b/src/irec/callbacks/base.py index 81271f59..569a04ac 100644 --- a/src/irec/callbacks/base.py +++ b/src/irec/callbacks/base.py @@ -1,344 +1,221 @@ -from irec.metric import BaseMetric, StatefullMetric - -import irec.utils -from irec.utils import MetaParent, create_logger - -import numpy as np -import os -import torch -from pathlib import Path - -logger = create_logger(name=__name__) - - -class BaseCallback(metaclass=MetaParent): - def __init__( - self, - model, - train_dataloader, - validation_dataloader, - eval_dataloader, - optimizer, - ): - self._model = model - self._train_dataloader = train_dataloader - self._validation_dataloader = validation_dataloader - self._eval_dataloader = eval_dataloader - self._optimizer = optimizer - - def __call__(self, inputs, step_num): - raise NotImplementedError - - -class MetricCallback(BaseCallback, config_name='metric'): - def __init__( - self, - model, - train_dataloader, - validation_dataloader, - eval_dataloader, - optimizer, - on_step, - metrics, - loss_prefix, - ): - super().__init__( - model, - train_dataloader, - validation_dataloader, - eval_dataloader, - optimizer, - ) - self._on_step = on_step - self._loss_prefix = loss_prefix - self._metrics = metrics if metrics is not None else {} - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], - metrics=config.get('metrics', None), - loss_prefix=config['loss_prefix'], - ) - - def __call__(self, inputs, step_num): - if step_num % self._on_step == 0: - for metric_name, metric_function in self._metrics.items(): - metric_value = metric_function( - ground_truth=inputs[ - self._model.schema['ground_truth_prefix'] - ], - predictions=inputs[ - self._model.schema['predictions_prefix'] - ], - ) - - irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.add_scalar( - 'train/{}'.format(metric_name), - metric_value, - step_num, - ) - - irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.add_scalar( - 'train/{}'.format(self._loss_prefix), - inputs[self._loss_prefix], - step_num, - ) - irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.flush() - - -class CheckpointCallback(BaseCallback, config_name='checkpoint'): - def __init__( - self, - model, - train_dataloader, - validation_dataloader, - eval_dataloader, - optimizer, - on_step, - save_path, - model_name, - ): - super().__init__( - model, - train_dataloader, - validation_dataloader, - eval_dataloader, - optimizer, - ) - self._on_step = on_step - self._save_path = Path(os.path.join(save_path, model_name)) - if self._save_path.exists(): - logger.warning( - 'Checkpoint path `{}` is already exists!'.format( - self._save_path, - ), - ) - else: - self._save_path.mkdir(parents=True, exist_ok=True) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], - save_path=config['save_path'], - model_name=config['model_name'], - ) - - def __call__(self, inputs, step_num): - if step_num % self._on_step == 0: - logger.debug('Saving model state on step {}...'.format(step_num)) - torch.save( - { - 'step_num': step_num, - 
'model_state_dict': self._model.state_dict(), - 'optimizer_state_dict': self._optimizer.state_dict(), - }, - os.path.join( - self._save_path, - 'checkpoint_{}.pth'.format(step_num), - ), - ) - logger.debug('Saving done!') - - -class InferenceCallback(BaseCallback): - def __init__( - self, - model, - train_dataloader, - validation_dataloader, - eval_dataloader, - optimizer, - on_step, - pred_prefix, - labels_prefix, - metrics=None, - loss_prefix=None, - ): - super().__init__( - model, - train_dataloader, - validation_dataloader, - eval_dataloader, - optimizer, - ) - self._on_step = on_step - self._metrics = metrics if metrics is not None else {} - self._pred_prefix = pred_prefix - self._labels_prefix = labels_prefix - self._loss_prefix = loss_prefix - - @classmethod - def create_from_config(cls, config, **kwargs): - metrics = { - metric_name: BaseMetric.create_from_config(metric_cfg, **kwargs) - for metric_name, metric_cfg in config['metrics'].items() - } - - return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], - metrics=metrics, - pred_prefix=config['pred_prefix'], - labels_prefix=config['labels_prefix'], - ) - - def __call__(self, inputs, step_num): - if step_num % self._on_step == 0: # TODO Add time monitoring - logger.debug(f'Running {self._get_name()} on step {step_num}...') - running_params = {} - for metric_name, metric_function in self._metrics.items(): - running_params[metric_name] = [] - if self._loss_prefix is not None: - running_params[self._loss_prefix] = [] - - self._model.eval() - with torch.no_grad(): - for batch in self._get_dataloader(): - for key, value in batch.items(): - batch[key] = value.to(irec.utils.DEVICE) - - batch[self._pred_prefix] = self._model(batch) - - for key, values in batch.items(): - batch[key] = values.cpu() - - for metric_name, metric_function in self._metrics.items(): - running_params[metric_name].extend( - metric_function( - inputs=batch, - pred_prefix=self._pred_prefix, - labels_prefix=self._labels_prefix, - ), - ) - - if self._loss_prefix is not None: - running_params[self._loss_prefix] += batch[ - self._loss_prefix - ].item() - - for metric_name, metric_function in self._metrics.items(): - if isinstance(metric_function, StatefullMetric): - running_params[metric_name] = metric_function.reduce( - running_params[metric_name], - ) - - for label, value in running_params.items(): - inputs[f'{self._get_name()}/{label}'] = np.mean(value) - irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.add_scalar( - f'{self._get_name()}/{label}', - np.mean(value), - step_num, - ) - irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.flush() - - logger.debug( - f'Running {self._get_name()} on step {step_num} is done!', - ) - - def _get_name(self): - return self.config_name - - def _get_dataloader(self): - raise NotImplementedError - - -class ValidationCallback(InferenceCallback, config_name='validation'): - @classmethod - def create_from_config(cls, config, **kwargs): - metrics = { - metric_name: BaseMetric.create_from_config(metric_cfg, **kwargs) - for metric_name, metric_cfg in config['metrics'].items() - } - - return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - on_step=config['on_step'], - metrics=metrics, - 
-            pred_prefix=config['pred_prefix'],
-            labels_prefix=config['labels_prefix'],
-        )
-
-    def _get_dataloader(self):
-        return self._validation_dataloader
-
-
-class EvalCallback(InferenceCallback, config_name='eval'):
-    @classmethod
-    def create_from_config(cls, config, **kwargs):
-        metrics = {
-            metric_name: BaseMetric.create_from_config(metric_cfg, **kwargs)
-            for metric_name, metric_cfg in config['metrics'].items()
-        }
-
-        return cls(
-            model=kwargs['model'],
-            train_dataloader=kwargs['train_dataloader'],
-            validation_dataloader=kwargs['validation_dataloader'],
-            eval_dataloader=kwargs['eval_dataloader'],
-            optimizer=kwargs['optimizer'],
-            on_step=config['on_step'],
-            metrics=metrics,
-            pred_prefix=config['pred_prefix'],
-            labels_prefix=config['labels_prefix'],
-        )
-
-    def _get_dataloader(self):
-        return self._eval_dataloader
-
-
-class CompositeCallback(BaseCallback, config_name='composite'):
-    def __init__(
-        self,
-        model,
-        train_dataloader,
-        validation_dataloader,
-        eval_dataloader,
-        optimizer,
-        callbacks,
-    ):
-        super().__init__(
-            model,
-            train_dataloader,
-            validation_dataloader,
-            eval_dataloader,
-            optimizer,
-        )
-        self._callbacks = callbacks
+import functools
+import inspect
+import types
+
+from functools import cached_property
+from typing import Any, Callable, Union
+
+
+class Callback:
+    """
+    Runners can trigger different callback methods (hooks) during execution.
+
+    The context assumes that callbacks running before the current one produce the input it needs.
+    The runner state is required by some callbacks and represents the literal runner state and the per-step state.
+
+    Typical hooks include:
+    * before_run
+    * load_snapshot
+    * runner step with specific callbacks
+    * after_step
+    * save_snapshot
+    * after_run
+    """
+    def state_dict(self):
+        """ Callback state_dict is used for snapshotting / checkpointing.
""" + return {} + + def load_state_dict(self, state_dict): + """ Should be able to load the state_dict it provided with method .state_dict """ + pass + + def save_snapshot(self, runner): + pass + + def load_snapshot(self, runner): + pass + + def before_run(self, runner): + pass + + def after_step(self, runner, context): + pass + + def after_run(self, runner, context): + pass + + declared_events = frozenset({before_run, load_snapshot, after_step, save_snapshot, after_run}) + + @cached_property + def implemented_events(self): + return frozenset({ + event + for event in self.declared_events + if getattr(self, event.__name__) != types.MethodType(event, self) + }) - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - model=kwargs['model'], - train_dataloader=kwargs['train_dataloader'], - validation_dataloader=kwargs['validation_dataloader'], - eval_dataloader=kwargs['eval_dataloader'], - optimizer=kwargs['optimizer'], - callbacks=[ - BaseCallback.create_from_config(cfg, **kwargs) - for cfg in config['callbacks'] - ], - ) - - def __call__(self, inputs, step_num): + def every_num_steps(self, num_steps): + return EveryNumSteps(self, num_steps) + + def ignore_if(self, predicate): + return Callback() if predicate else self + + +class BatchCallback(Callback): + def before_load(self, runner, context): + pass + + def before_process_batch(self, runner, context): + pass + + declared_events = Callback.declared_events | frozenset({before_load, before_process_batch}) + + +class Composite(Callback): + def __init__(self, *callbacks, declared_events=None): + super().__init__() + self._callbacks = callbacks + self._declared_events = frozenset({event for callback in self._callbacks for event in callback.declared_events}) + if declared_events is not None: + self._declared_events = frozenset(declared_events) + if not self.implemented_events.issubset(self._declared_events): + raise TypeError(f'Not declared events were found in callbacks: {self.implemented_events - self._declared_events}') + for event in self._declared_events: + if hasattr(self, event.__name__) and event not in Callback.declared_events: + raise TypeError(f'Event {event} conflict in CompositeCallback') + setattr(self, event.__name__, functools.partial(self._emit, event)) + + @property + def callbacks(self): + return self._callbacks + + def __len__(self): + return len(self._callbacks) + + def __getitem__(self, index): + return self._callbacks[index] + + def state_dict(self): + return [callback.state_dict() for callback in self._callbacks] + + def load_state_dict(self, state_dict): + for callback, state in zip(self._callbacks, state_dict): + callback.load_state_dict(state) + + @property + def declared_events(self): + return self._declared_events + + @cached_property + def implemented_events(self): + return frozenset({event for callback in self._callbacks for event in callback.implemented_events}) + + def _emit(self, event, runner, *args, **kwargs): for callback in self._callbacks: - callback(inputs, step_num) + if event in callback.declared_events: + getattr(callback, event.__name__)(runner, *args, **kwargs) + + +class EveryNumSteps(Callback): + def __init__(self, callback, num_steps): + super().__init__() + self._callback = callback + self._num_steps = num_steps + for event in self.declared_events: + if hasattr(self, event.__name__) and event not in Callback.declared_events: + raise TypeError(f'Event {event} conflict in EveryNumSteps') + setattr(self, event.__name__, functools.partial(self._emit, event)) + + @property + def 
callback(self): + return self._callback + + def state_dict(self): + return self._callback.state_dict() + + def load_state_dict(self, state_dict): + return self._callback.load_state_dict(state_dict) + + @property + def declared_events(self): + return self._callback.declared_events + + @property + def implemented_events(self): + return self._callback.implemented_events + + def _emit(self, event, runner, *args, **kwargs): + if runner.global_step % self._num_steps == 0 or runner.global_finished: + getattr(self._callback, event.__name__)(runner, *args, **kwargs) + + +class LambdaCallback(Callback): + def __init__(self, function: Union[Callable[[], Any], Callable[..., Any]]): + super().__init__() + self._function = function + self._has_args = (len(inspect.signature(function).parameters) > 0) + self._check_signature(inspect.signature(function if self._has_args else self._emit)) # TODO try remove if else and pass `self._emit` + + def _emit(self, *args, **kwargs): + return self._function(*args, **kwargs) if self._has_args else self._function() + + def _check_signature(self, signature): + pass + + +class SaveSnapshot(LambdaCallback): + def save_snapshot(self, runner): + self._emit(runner) + + def _check_signature(self, signature): + signature.bind(None) + + +class LoadSnapshot(LambdaCallback): + def load_snapshot(self, runner): + self._emit(runner) + + def _check_signature(self, signature): + signature.bind(None) + + +class BeforeRun(LambdaCallback): + def before_run(self, runner): + self._emit(runner) + + def _check_signature(self, signature): + signature.bind(None) + + +class BeforeLoad(BatchCallback, LambdaCallback): + def before_load(self, runner, context): + self._emit(runner, context) + + def _check_signature(self, signature): + signature.bind(None, None) + + +class BeforeBatch(BatchCallback, LambdaCallback): + def before_batch(self, runner, context): + self._emit(runner, context) + + def _check_signature(self, signature): + signature.bind(None, None) + + +class AfterStep(LambdaCallback): + def after_step(self, runner, context): + self._emit(runner, context) + + def _check_signature(self, signature): + signature.bind(None, None) + + +class AfterRun(LambdaCallback): + def after_run(self, runner, context): + self._emit(runner, context) + + def _check_signature(self, signature): + signature.bind(None, None) + diff --git a/src/irec/callbacks/logging.py b/src/irec/callbacks/logging.py new file mode 100644 index 00000000..aa2febdc --- /dev/null +++ b/src/irec/callbacks/logging.py @@ -0,0 +1,151 @@ +import datetime +import os +import shutil + +from loguru import logger + +from torch.utils.tensorboard import SummaryWriter + +from irec.callbacks.base import Callback +from irec.callbacks.train import TrainingCallback +from irec.runners.base import Runner, RunnerContext +from irec.runners.train import TrainingRunner, TrainingRunnerContext + + +class LoggingCallback(TrainingCallback): + def before_run(self, runner: TrainingRunner): + logger.debug('Before run') + pass + + def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext): + logger.debug(f'After step {runner.global_step}') + + def after_run(self, runner: TrainingRunner, context: TrainingRunnerContext): + logger.debug('After run') + + def before_load(self, runner: TrainingRunner, context: TrainingRunnerContext): + logger.debug('Before load') + + def before_process_batch(self, runner: TrainingRunner, context: TrainingRunnerContext): + logger.debug('Before process batch') + + def before_optimizer(self, runner: TrainingRunner, 
context: TrainingRunnerContext): + logger.debug('Before optimizer') + + +class TensorboardWriter(SummaryWriter): + def __init__( + self, + experiment_name, + log_dir, + use_time=True + ): + self._experiment_name = experiment_name + os.makedirs(log_dir, exist_ok=True) + super().__init__( + log_dir=os.path.join( + log_dir, + f'{experiment_name}_{datetime.datetime.now().strftime("%Y-%m-%dT%H:%M" if use_time else "")}' + ) + ) + + def add_scalar(self, *args, **kwargs): + super().add_scalar(*args, **kwargs) + + +class TensorboardLogger(Callback): + def __init__(self, experiment_name, logdir, rewrite=False, step_factor=1): + super().__init__() + self._experiment_name = experiment_name + self._logdir = logdir + self._rewrite = rewrite + self._step_factor = step_factor + self._writer = None + self._runner = None + + def _setup_writer(self, global_step): + if self._rewrite and global_step == 1: + if os.path.exists(self._logdir): + shutil.rmtree(self._logdir) + if self._writer is None: + self._writer = TensorboardWriter(experiment_name=self._experiment_name, log_dir=self._logdir) + + def state_dict(self): + if self._writer is not None: + self._writer.flush() + return {} + + def load_state_dict(self, state_dict): + self.close_writer() + + def before_run(self, runner): + if self._runner is None: + self._runner = runner + + def after_step(self, runner: Runner, context: RunnerContext): + self._setup_writer(runner.global_step) + for key, value in context.metrics.items(): + self._writer.add_scalar(key, value, runner.global_step // self._step_factor) + + def after_run(self, runner, context): + if self._runner is runner: + self._runner = None + self.close_writer() + + def close_writer(self): + if self._writer is not None: + self._writer.flush() + self._writer.close() + self._writer = None + + +class Logger(Callback): + def __init__(self, logfile=None, name=None, step_factor=1): + super().__init__() + self._logfile = logfile + # self._writer = logfile if isinstance(logfile, io.IOBase) else None + self._writer = None + self._name = name + self._step_factor = step_factor + self._runner = None + + def state_dict(self): + if self._writer is not None: + self._writer.flush() + return {} + + def load_state_dict(self, state_dict): + self.close_writer() + + def before_run(self, runner): + if self._runner is None: + self._runner = runner + logger.info('Starting run') + + def after_step(self, runner: Runner, context: RunnerContext): + msg = [f'step {runner.global_step // self._step_factor}'] + if self._name is not None: + msg.insert(0, self._name) + for key, value in context.metrics.items(): + msg.append(f'{key} {value}') + if self._logfile is not None: + if self._writer is None: + self._writer = open(self._logfile, 'at') + msg.insert(0, datetime.datetime.now().isoformat(' ', 'milliseconds')) + self._writer.write(', '.join(msg) + '\n') + self._writer.flush() + else: + logger.info(', '.join(msg)) + + def after_run(self, runner, context): + if self._runner is runner: + self._runner = None + self.close_writer() + logger.info('Finishing run') + + def close_writer(self): + if self._writer is not None: + self._writer.flush() + if self._writer is not self._logfile: + self._writer.close() + self._writer = None diff --git a/src/irec/callbacks/metrics.py b/src/irec/callbacks/metrics.py new file mode 100644 index 00000000..40704be6 --- /dev/null +++ b/src/irec/callbacks/metrics.py @@ -0,0 +1,175 @@ +import math + +import torch +from loguru import logger + +from irec.callbacks.base import BatchCallback, Callback, LambdaCallback 
+from irec.callbacks.train import TrainingCallback
+
+from irec.runners.base import Runner, RunnerContext
+from irec.runners.train import TrainingRunner, TrainingRunnerContext
+
+from irec.runners.inference import InferenceRunner
+
+
+class BatchMetrics(TrainingCallback):
+    def __init__(self, metrics, name=None, separator='/'):
+        self._metrics = metrics
+        self._name = name
+        self._separator = separator
+
+    def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext):
+        metrics = self._metrics(context.model_outputs, context.batch)
+        BatchMetrics.add_context_metrics(context, metrics, name=self._name, separator=self._separator)
+
+    @staticmethod
+    def add_context_metrics(context: TrainingRunnerContext, metrics, *, name=None, separator='/'):
+        metrics = {name: metrics} if name is not None else metrics
+        metrics = BatchMetrics.flatten_nested_metrics(metrics, separator=separator)
+
+        for metric_name, metric_value in metrics.items():
+            if metric_name in context.metrics:
+                raise ValueError(f'Metric already exists: {metric_name}')
+
+            assert isinstance(metric_value, (list, float, int))
+
+            context.metrics[metric_name] = metric_value
+
+    @staticmethod
+    def flatten_nested_metrics(metrics, *, separator='/'):
+        if not isinstance(metrics, dict):
+            raise TypeError('If name is None, metrics must return a dict')
+        result = dict()
+        for metric_name, metric_value in metrics.items():
+            if isinstance(metric_value, dict):
+                for key, value in BatchMetrics.flatten_nested_metrics(metric_value, separator=separator).items():
+                    result[metric_name + separator + key] = value
+            else:
+                result[metric_name] = metric_value
+        return result
+
+
+class LambdaMetrics(LambdaCallback):
+    def __init__(self, function, name=None, separator='/'):
+        super().__init__(function)
+        self._name = name
+        self._separator = separator
+
+    def after_step(self, runner: Runner, context: RunnerContext):
+        metrics = self._emit(runner, context)
+        BatchMetrics.add_context_metrics(context, metrics, name=self._name, separator=self._separator)
+
+
+class Accumulator:
+    def accumulate(self, value):
+        pass
+
+    def state_dict(self):
+        return {}
+
+    def load_state_dict(self, state_dict):
+        pass
+
+    def reduce(self):
+        pass
+
+    def clear(self):
+        pass
+
+
+class MeanAccumulator(Accumulator):
+    def __init__(self):
+        super().__init__()
+        self._accumulated_values = []
+
+    def accumulate(self, values):
+        if isinstance(values, list):
+            self._accumulated_values.extend(values)
+        else:
+            self._accumulated_values.append(values)
+
+    def state_dict(self):
+        return {'values': self._accumulated_values}
+
+    def load_state_dict(self, state_dict):
+        self._accumulated_values = state_dict['values']
+
+    def reduce(self):
+        return sum(self._accumulated_values) / len(self._accumulated_values)
+
+    def clear(self):
+        self._accumulated_values = []
+
+
+class MetricAccumulator(Callback):
+    def __init__(
+        self,
+        accumulators: dict[str, Accumulator],
+        *,
+        reset_every_num_steps=None,
+    ):
+        super().__init__()
+        self._accumulators = accumulators
+        self._reset_every_num_steps = reset_every_num_steps
+
+    def state_dict(self):
+        state_dict = {}
+        for idx, accumulator in enumerate(self._accumulators.values()):
+            state_dict[idx] = accumulator.state_dict()
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        for idx, accumulator in enumerate(self._accumulators.values()):
+            accumulator.load_state_dict(state_dict[idx])
+
+    def before_run(self, runner: Runner):
+        self.clear()
+
+    def after_step(self, runner: Runner, context: RunnerContext):
+        for name, accumulator in self._accumulators.items():
+            accumulator.accumulate(context.metrics[name])
+        self.reduce(context)
+        if self._reset_every_num_steps is not None and runner.global_step % self._reset_every_num_steps == 0:
+            self.clear()
+
+    def reduce(self, context: RunnerContext):
+        for name, accumulator in self._accumulators.items():
+            context.metrics[name] = accumulator.reduce()
+
+    def after_run(self, runner: Runner, context: RunnerContext):
+        if self._reset_every_num_steps is None:
+            self.reduce(context)
+        self.clear()
+
+    def clear(self):
+        for accumulator in self._accumulators.values():
+            accumulator.clear()
+
+
+class Validation(TrainingCallback):
+    def __init__(
+        self,
+        dataset,
+        callbacks,
+        *,
+        model=None
+    ):
+        if hasattr(dataset, '__next__') and not hasattr(dataset, '__getitem__'):
+            raise TypeError(f'Dataset is expected to be an iterable, not an iterator; got {type(dataset)}')
+        self._dataset = dataset
+        self._model = model
+        self._callbacks = callbacks
+
+    def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext):
+        logger.info('Doing validation')
+
+        inference_result = InferenceRunner(
+            model=(self._model if self._model is not None else runner.model),
+            dataset=self._dataset,
+            callbacks=self._callbacks
+        ).run()
+
+        for name, value in inference_result.metrics.items():
+            if name in context.metrics:
+                raise ValueError(f'Metric already exists: {name}')
+            context.metrics[name] = value
+
diff --git a/src/irec/callbacks/model.py b/src/irec/callbacks/model.py
new file mode 100644
index 00000000..78f41384
--- /dev/null
+++ b/src/irec/callbacks/model.py
@@ -0,0 +1,17 @@
+from loguru import logger
+
+import torch
+
+
+from irec.callbacks.train import TrainingCallback
+from irec.runners.train import TrainingRunner
+
+
+class LoadModel(TrainingCallback):
+    def __init__(self, model_path):
+        super().__init__()
+        self.model_path = model_path
+
+    def before_run(self, runner: TrainingRunner):
+        runner.model.load_state_dict(torch.load(self.model_path, weights_only=True))
+        logger.debug(f'Model {self.model_path} is loaded!')
diff --git a/src/irec/callbacks/profiler.py b/src/irec/callbacks/profiler.py
new file mode 100644
index 00000000..7dd65348
--- /dev/null
+++ b/src/irec/callbacks/profiler.py
@@ -0,0 +1,42 @@
+import torch
+from loguru import logger
+
+from irec.callbacks.base import Callback
+from irec.runners.base import Runner, RunnerContext
+
+
+class Profiler(Callback):
+    def __init__(self, wait, warmup, active, logdir, worker_idx=0):
+        assert wait + warmup > 0, 'Should have at least some warmup before profiling'
+        self._wait = wait
+        self._warmup = warmup
+        self._active = active
+        self._logdir = logdir
+        self._worker_idx = worker_idx
+        self._curr_step = 0
+        self._profiler = None
+
+    def before_run(self, runner: Runner):
+        if self._profiler is None:
+            logger.info('Creating profiler')
+            self._profiler = torch.profiler.profile(
+                activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+                schedule=torch.profiler.schedule(
+                    wait=self._wait,
+                    warmup=self._warmup,
+                    active=self._active
+                ),
+                record_shapes=True,
+                with_stack=True,
+                on_trace_ready=torch.profiler.tensorboard_trace_handler(self._logdir, worker_name=f'worker{self._worker_idx}'),
+                profile_memory=True
+            )
+            self._profiler.start()
+
+    def after_step(self, runner: Runner, context: RunnerContext):
+        if self._profiler is not None:
+            self._profiler.step()
+            if self._curr_step > self._wait + self._warmup + self._active:
+                self._profiler.stop()
+                self._profiler = None
+        self._curr_step += 1
diff --git a/src/irec/callbacks/stats.py b/src/irec/callbacks/stats.py
new file mode 100644
index 00000000..4e3f299d
--- /dev/null
+++ b/src/irec/callbacks/stats.py
@@ -0,0 +1,73 @@
+import torch
+from typing import Dict, Any
+
+from irec.callbacks.base import BatchCallback
+from irec.runners.base import BatchRunner, BatchRunnerContext
+
+
+class Thermometer(BatchCallback):
+    def __init__(self, stats=None, **modules: torch.nn.Module):
+        super().__init__()
+        for module in modules.values():
+            assert isinstance(module, torch.nn.Module)
+        self.modules = modules
+        self.hooks: Dict[str, torch.utils.hooks.RemovableHandle] = {}
+        self.stats = stats or ['max', 'min', 'mean', 'median', 'std']
+
+    def before_batch(self, runner: BatchRunner, context: BatchRunnerContext) -> None:
+        self._register_all(context)
+
+    def after_step(self, runner: BatchRunner, context: BatchRunnerContext) -> None:
+        self._remove_all()
+
+    def _register_all(self, context: BatchRunnerContext) -> None:
+        for name, layer in self.modules.items():
+            if name not in self.hooks:
+                self.hooks[name] = layer.register_forward_hook(
+                    self._make_hook(name, context)
+                )
+
+    def _remove_all(self) -> None:
+        for name, handle in list(self.hooks.items()):
+            handle.remove()
+            del self.hooks[name]
+
+    def _calculate_stats(self, tensor: torch.Tensor, layer_name: str, context: Any) -> None:
+        if not isinstance(tensor, torch.Tensor):
+            raise TypeError(f"Expected torch.Tensor, got {type(tensor)}")
+
+        for stat in self.stats:
+            if callable(stat):
+                value = stat(tensor)
+                stat_name = stat.__name__
+            elif stat == 'max':
+                value = tensor.amax().detach()
+                stat_name = 'max'
+            elif stat == 'min':
+                value = tensor.amin().detach()
+                stat_name = 'min'
+            elif stat == 'mean':
+                value = tensor.mean().detach()
+                stat_name = 'mean'
+            elif stat == 'median':
+                value = tensor.median().detach()
+                stat_name = 'median'
+            elif stat == 'std':
+                value = tensor.std().detach()
+                stat_name = 'std'
+            else:
+                raise ValueError(f"Unknown statistic: {stat}")
+
+            context.metrics[f"act/{layer_name}_{stat_name}"] = value
+
+    def _make_hook(self, layer_name: str, context: Any):
+        def hook(_, __, out):
+            if isinstance(out, torch.Tensor):
+                tensor = out
+            elif isinstance(out, (tuple, list)) and len(out) > 0 and isinstance(out[0], torch.Tensor):
+                tensor = out[0]
+            else:
+                return
+            self._calculate_stats(tensor, layer_name, context)
+
+        return hook
diff --git a/src/irec/callbacks/stopping.py b/src/irec/callbacks/stopping.py
new file mode 100644
index 00000000..3d1405fc
--- /dev/null
+++ b/src/irec/callbacks/stopping.py
@@ -0,0 +1,74 @@
+import copy
+from loguru import logger
+import os
+
+import torch
+
+from irec.callbacks.base import Callback
+from irec.runners.base import Runner, RunnerContext
+
+
+# TODO: make a separate callback for saving the model
+
+class EarlyStopping(Callback):
+    def __init__(
+        self,
+        metric,
+        patience,
+        *,
+        minimize=True,
+        model_path=None,
+    ):
+        self._metric = metric
+        self._best_metric = None
+        self._minimize = minimize
+
+        self._patience = patience
+        self._wait = 0
+
+        self._best_model_state_dict = None
+        self._model_path = model_path
+
+    def state_dict(self):
+        return {
+            'wait': self._wait,
+            'best_metric': self._best_metric
+        }
+
+    def load_state_dict(self, state_dict):
+        self._wait = state_dict['wait']
+        self._best_metric = state_dict['best_metric']
+
+    def after_step(self, runner: Runner, context: RunnerContext):
+        assert self._metric in context.metrics
+        metric = context.metrics[self._metric]
diff --git a/src/irec/callbacks/stopping.py b/src/irec/callbacks/stopping.py
new file mode 100644
index 00000000..3d1405fc
--- /dev/null
+++ b/src/irec/callbacks/stopping.py
@@ -0,0 +1,74 @@
+import copy
+from loguru import logger
+import os
+
+import torch
+
+from irec.callbacks.base import Callback
+from irec.runners.base import Runner, RunnerContext
+
+
+# TODO: move model saving into a separate callback
+
+class EarlyStopping(Callback):
+    def __init__(
+        self,
+        metric,
+        patience,
+        *,
+        minimize=True,
+        model_path=None,
+    ):
+        self._metric = metric
+        self._best_metric = None
+        self._minimize = minimize
+
+        self._patience = patience
+        self._wait = 0
+
+        self._best_model_state_dict = None
+        self._model_path = model_path
+
+    def state_dict(self):
+        return {
+            'wait': self._wait,
+            'best_metric': self._best_metric
+        }
+
+    def load_state_dict(self, state_dict):
+        self._wait = state_dict['wait']
+        self._best_metric = state_dict['best_metric']
+
+    def after_step(self, runner: Runner, context: RunnerContext):
+        assert self._metric in context.metrics
+        metric = context.metrics[self._metric]
+        if self._best_metric is None:
+            self._best_metric = metric
+            torch.save(runner.model.state_dict(), f'{self._model_path}_best_{round(self._best_metric, 4)}.pth')
+        else:
+            if (self._minimize and metric < self._best_metric) or (not self._minimize and metric > self._best_metric):
+                self._wait = 0
+                old_metric = self._best_metric
+                self._best_metric = metric
+                # Save the new best model
+                torch.save(runner.model.state_dict(), f'{self._model_path}_best_{round(self._best_metric, 4)}.pth')
+                # Delete the previous best model
+                if str(round(self._best_metric, 4)) != str(round(old_metric, 4)):
+                    os.remove(f'{self._model_path}_best_{round(old_metric, 4)}.pth')
+                logger.info(f'New best value for {self._metric}: {self._best_metric:.4f}')
+            else:
+                self._wait += 1
+                logger.info(f'Wait is increased to {self._wait}')
+                if self._wait == self._patience:
+                    logger.info(f'Patience for {self._metric} is reached: could not beat value {self._best_metric:.4f} for {self._wait} calls')
+                    raise StopIteration
+
+
+class StopAfterNumSteps(Callback):
+    def __init__(self, num_steps):
+        self._num_steps = num_steps
+
+    def after_step(self, runner: Runner, context: RunnerContext):
+        if runner.global_step >= self._num_steps:
+            raise StopIteration
+
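A usage sketch of the two stopping callbacks (the metric name and checkpoint prefix here are assumptions for illustration):

    from irec.callbacks.stopping import EarlyStopping, StopAfterNumSteps

    callbacks = [
        # Stop once validation NDCG@10 has not improved for 5 evaluations;
        # every new best state dict is saved as 'ckpt/sasrec_best_<value>.pth'.
        EarlyStopping('val/ndcg@10', patience=5, minimize=False, model_path='ckpt/sasrec'),
        StopAfterNumSteps(100_000),
    ]

Both callbacks signal termination by raising StopIteration, so the runner's step loop is expected to catch it and finish the run cleanly.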
diff --git a/src/irec/callbacks/timer.py b/src/irec/callbacks/timer.py
new file mode 100644
index 00000000..a855aad9
--- /dev/null
+++ b/src/irec/callbacks/timer.py
@@ -0,0 +1,64 @@
+import time
+
+from irec.callbacks.base import BatchCallback, Callback
+from irec.runners.base import BatchRunner, BatchRunnerContext, Runner, RunnerContext
+
+
+class CpuTimer:
+    # TODO: implement CudaTimer
+    def __init__(self):
+        super().__init__()
+        self._start = None
+
+    def start(self):
+        assert self._start is None
+        self._start = time.perf_counter()
+
+    def stop(self):
+        assert self._start is not None
+        result = (time.perf_counter() - self._start) * 1000.  # milliseconds
+        self._start = None
+        return result
+
+
+class MeasureStepTime(BatchCallback):
+    def __init__(self, name='time/step'):
+        super().__init__()
+        self._name = name
+        self._timer = CpuTimer()
+
+    def before_batch(self, runner: BatchRunner, context: BatchRunnerContext):
+        self._timer.start()
+
+    def after_step(self, runner: BatchRunner, context: BatchRunnerContext):
+        context.metrics[self._name] = self._timer.stop()
+
+
+class MeasureLoadingTime(BatchCallback):
+    def __init__(self, name='time/load'):
+        super().__init__()
+        self._name = name
+        self._timer = CpuTimer()
+
+    def before_load(self, runner: BatchRunner, context: BatchRunnerContext):
+        self._timer.start()
+
+    def before_batch(self, runner: BatchRunner, context: BatchRunnerContext):
+        context.metrics[self._name] = self._timer.stop()
+
+
+class MeasureTotalStepTime(Callback):
+    def __init__(self, name='time/total'):
+        super().__init__()
+        self._name = name
+        self._timer = CpuTimer()
+
+    def before_run(self, runner: Runner):
+        self._timer.start()
+
+    def after_step(self, runner: Runner, context: RunnerContext):
+        context.metrics[self._name] = self._timer.stop()
+        self._timer.start()
+
+    def after_run(self, runner: Runner, context: RunnerContext):
+        self._timer.stop()
diff --git a/src/irec/callbacks/train.py b/src/irec/callbacks/train.py
new file mode 100644
index 00000000..e678aca2
--- /dev/null
+++ b/src/irec/callbacks/train.py
@@ -0,0 +1,39 @@
+import torch
+
+from irec.callbacks.base import BatchCallback, LambdaCallback
+from irec.runners.train import TrainingRunner, TrainingRunnerContext
+
+
+class TrainingCallback(BatchCallback):
+    def before_optimizer(self, runner: TrainingRunner, context: TrainingRunnerContext):
+        pass
+
+    declared_events = BatchCallback.declared_events | frozenset({before_optimizer})
+
+
+class BeforeOptimizer(TrainingCallback, LambdaCallback):
+    def before_optimizer(self, runner, context):
+        self._emit(runner, context)
+
+    def _check_signature(self, signature):
+        signature.bind(None, None)
+
+
+class ClipGradient(TrainingCallback):
+    def __init__(self, parameters=None, value=1.0, *, name=None):
+        self._parameters = list(parameters) if parameters is not None else None
+        self._value = value
+        self._name = name
+
+    @torch.no_grad()
+    def before_optimizer(self, runner: TrainingRunner, context: TrainingRunnerContext):
+        if self._parameters is not None:
+            norm_before_clip = torch.nn.utils.clip_grad_norm_(self._parameters, self._value)
+        elif not hasattr(runner.model, 'clip_grad_norm_'):
+            norm_before_clip = torch.nn.utils.clip_grad_norm_(runner.model.parameters(), self._value)
+        else:
+            norm_before_clip = runner.model.clip_grad_norm_(self._value)
+        if self._name is not None:
+            if self._name in context.metrics:
+                raise ValueError(f'Metric name "{self._name}" already exists in context.metrics!')
+            context.metrics[self._name] = norm_before_clip.item()
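A sketch of gradient clipping with norm logging (the metric name is an assumption for illustration):

    from irec.callbacks.train import ClipGradient

    # Clip the full model's gradients to norm 1.0 before each optimizer step
    # and expose the pre-clip norm as a metric.
    clip = ClipGradient(value=1.0, name='grad/norm')

The model.clip_grad_norm_ branch lets sharded wrappers that define their own clipping method (e.g. FSDP, which must clip across shards) take precedence over the plain torch.nn.utils call.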
diff --git a/src/irec/scheduler/__init__.py b/src/irec/data/__init__.py
similarity index 100%
rename from src/irec/scheduler/__init__.py
rename to src/irec/data/__init__.py
diff --git a/src/irec/data/base.py b/src/irec/data/base.py
new file mode 100644
index 00000000..3d1a17fc
--- /dev/null
+++ b/src/irec/data/base.py
@@ -0,0 +1,26 @@
+class BaseDataset:
+    def __getitem__(self, idx):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+    def map(self, mapper):
+        return MapDataset(
+            dataset=self,
+            mapper=mapper
+        )
+
+
+class MapDataset(BaseDataset):
+    def __init__(self, dataset, mapper):
+        self.dataset = dataset
+        self.mapper = mapper
+
+    def __getitem__(self, idx):
+        return self.mapper(
+            self.dataset[idx]
+        )
+
+    def __len__(self):
+        return len(self.dataset)
\ No newline at end of file
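A small illustration of the lazy map (ToyDataset is hypothetical):

    from irec.data.base import BaseDataset

    class ToyDataset(BaseDataset):
        def __init__(self, items):
            self._items = items

        def __getitem__(self, idx):
            return self._items[idx]

        def __len__(self):
            return len(self._items)

    doubled = ToyDataset([1, 2, 3]).map(lambda x: 2 * x)
    assert doubled[1] == 4  # the mapper runs on access, not up front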
diff --git a/src/irec/data/dataloader.py b/src/irec/data/dataloader.py
new file mode 100644
index 00000000..ca32eeea
--- /dev/null
+++ b/src/irec/data/dataloader.py
@@ -0,0 +1,201 @@
+from loguru import logger
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import Dataset, Sampler
+from torch.utils.data.distributed import DistributedSampler
+from torchdata.stateful_dataloader import StatefulDataLoader
+from typing import Callable, Optional, Iterator, Any, Dict
+
+
+class DataLoader:
+    def __init__(
+        self,
+        dataset: Dataset,
+        batch_size: int,
+        num_workers: int = 0,
+        shuffle: bool = False,
+        drop_last: bool = False,
+        sampler: Optional[Sampler] = None,
+        **dataloader_args
+    ):
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.num_workers = num_workers
+        self.shuffle = shuffle if sampler is None else False
+        self.drop_last = drop_last
+        self.sampler = sampler
+        self.dataloader_args = dataloader_args
+
+        self._dataloader = None
+
+    @property
+    def dataloader(self) -> StatefulDataLoader:
+        if self._dataloader is None:
+            self._dataloader = self._create_dataloader()
+        return self._dataloader
+
+    def _collate_fn(self, samples):
+        return samples
+
+    def _create_dataloader(self) -> StatefulDataLoader:
+        loader_kwargs = {
+            'batch_size': self.batch_size,
+            'num_workers': self.num_workers,
+            'drop_last': self.drop_last,
+            'collate_fn': self._collate_fn,
+            **self.dataloader_args
+        }
+
+        if self.sampler is not None:
+            loader_kwargs['sampler'] = self.sampler
+            loader_kwargs['shuffle'] = False
+        else:
+            loader_kwargs['shuffle'] = self.shuffle
+
+        return StatefulDataLoader(dataset=self.dataset, **loader_kwargs)
+
+    def __iter__(self) -> Iterator:
+        for batch in self.dataloader:
+            yield batch
+
+    def __len__(self) -> int:
+        return len(self.dataloader)
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {'dataloader': self.dataloader.state_dict()}
+
+    def load_state_dict(self, state_dict: Dict[str, Any]):
+        self.dataloader.load_state_dict(state_dict['dataloader'])
+
+    def shards(self, world_size: int, rank: int, seed: int = 0) -> 'ShardedDataLoader':
+        return ShardedDataLoader(
+            dataset=self.dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            shuffle=self.shuffle,
+            drop_last=self.drop_last,
+            rank=rank,
+            world_size=world_size,
+            seed=seed,
+            **self.dataloader_args
+        )
+
+    def repeat(self, num_epochs: int):
+        return RepeatedDataLoader(base_loader=self, num_epochs=num_epochs)
+
+    def map(self, mapper):
+        return MappedDataloader(base_loader=self, mapper=mapper)
+
+
+class ShardedDataLoader(DataLoader):
+    def __init__(self, dataset: Dataset, batch_size: int,
+                 rank: int, world_size: int,
+                 num_workers: int = 0, shuffle: bool = False,
+                 drop_last: bool = False, seed: int = 0,
+                 **kwargs):
+        self.rank = rank
+        self.world_size = world_size
+        self.seed = seed
+        self._current_epoch = 0
+
+        sampler = DistributedSampler(
+            dataset,
+            num_replicas=world_size,
+            rank=rank,
+            shuffle=shuffle,
+            seed=seed,
+            drop_last=drop_last
+        )
+
+        super().__init__(
+            dataset=dataset,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            shuffle=False,  # shuffle is handled by the sampler
+            drop_last=drop_last,
+            sampler=sampler,
+            **kwargs
+        )
+
+    def set_epoch(self, epoch: int):
+        self._current_epoch = epoch
+        if isinstance(self.sampler, DistributedSampler):
+            self.sampler.set_epoch(epoch)
+
+    def repeat(self, num_epochs: int):
+        return RepeatedDataLoader(
+            base_loader=self,
+            num_epochs=num_epochs
+        )
+
+
+class RepeatedDataLoader:
+    def __init__(self, base_loader: DataLoader, num_epochs: int):
+        self.base_loader = base_loader
+        self.num_epochs = num_epochs
+        self.current_epoch = 0
+
+    def __iter__(self) -> Iterator:
+        for epoch in range(self.current_epoch, self.num_epochs):
+            logger.debug(f'Starting epoch: {epoch + 1}')
+
+            self.current_epoch = epoch
+
+            if hasattr(self.base_loader, 'sampler') and hasattr(self.base_loader.sampler, 'set_epoch'):
+                self.base_loader.sampler.set_epoch(epoch)
+
+            self.base_loader._dataloader = None
+
+            for batch in self.base_loader:
+                yield batch
+
+            self.current_epoch = epoch + 1
+
+        self.current_epoch = 0
+
+    def __len__(self) -> int:
+        return len(self.base_loader) * self.num_epochs
+
+    def state_dict(self) -> Dict[str, Any]:
+        return {
+            'base_loader': self.base_loader.state_dict(),
+            'current_epoch': self.current_epoch,
+            'num_epochs': self.num_epochs,
+        }
+
+    def load_state_dict(self, state_dict: Dict[str, Any]):
+        self.current_epoch = state_dict['current_epoch']
+        self.num_epochs = state_dict['num_epochs']
+        self.base_loader.load_state_dict(state_dict['base_loader'])
+
+    def map(self, mapper: Callable) -> 'MappedDataloader':
+        return MappedDataloader(base_loader=self, mapper=mapper)
+
+    def repeat(self, num_epochs: int) -> 'RepeatedDataLoader':
+        return RepeatedDataLoader(base_loader=self, num_epochs=num_epochs)
+
+
+class MappedDataloader(DataLoader):
+    def __init__(self, base_loader: DataLoader, mapper: Callable):
+        self.base_loader = base_loader
+        self.mapper = mapper
+
+    def __iter__(self) -> Iterator:
+        for batch in self.base_loader:
+            yield self.mapper(batch)
+
+    def __len__(self) -> int:
+        return len(self.base_loader)
+
+    def state_dict(self) -> Dict[str, Any]:
+        return self.base_loader.state_dict()
+
+    def load_state_dict(self, state_dict: Dict[str, Any]):
+        self.base_loader.load_state_dict(state_dict)
+
+    def map(self, mapper: Callable) -> 'MappedDataloader':
+        return MappedDataloader(base_loader=self, mapper=mapper)
+
+    def repeat(self, num_epochs: int) -> 'RepeatedDataLoader':
+        return RepeatedDataLoader(base_loader=self, num_epochs=num_epochs)
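A sketch of the intended composition (dataset, rank, and the batch transform are placeholders):

    from irec.data.dataloader import DataLoader

    loader = (
        DataLoader(dataset, batch_size=256, num_workers=4, shuffle=True)
        .shards(world_size=8, rank=rank)  # DistributedSampler under the hood
        .repeat(num_epochs=10)            # rebuilds the iterator and calls set_epoch
        .map(batch_transform)             # per-batch transform, e.g. Collate + ToTorch
    )

    for batch in loader:
        ...

state_dict()/load_state_dict() delegate through the whole chain, so a resumed run continues from the saved epoch and, via StatefulDataLoader, from the saved position within the epoch.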
diff --git a/src/irec/data/transforms/__init__.py b/src/irec/data/transforms/__init__.py
new file mode 100644
index 00000000..e16131a7
--- /dev/null
+++ b/src/irec/data/transforms/__init__.py
@@ -0,0 +1,8 @@
+from irec.data.transforms.base import Transform, Collate, ToDevice, ToTorch
+
+__all__ = [
+    'Transform',
+    'Collate',
+    'ToDevice',
+    'ToTorch',
+]
\ No newline at end of file
diff --git a/src/irec/data/transforms/base.py b/src/irec/data/transforms/base.py
new file mode 100644
index 00000000..b6df0bd9
--- /dev/null
+++ b/src/irec/data/transforms/base.py
@@ -0,0 +1,74 @@
+import numpy as np
+import torch
+
+
+class Transform:
+    def __call__(self, sample):
+        raise NotImplementedError
+
+
+class Mapper(Transform):
+    def __call__(self, sample):
+        for k, v in sample.items():
+            if isinstance(v, dict):
+                sample[k] = self.__call__(v)
+            else:
+                sample[k] = self._mapper(v)
+        return sample
+
+    def _mapper(self, x):
+        raise NotImplementedError
+
+
+class ToDevice(Mapper):
+    def __init__(self, device):
+        self.device = device
+
+    def _mapper(self, value):
+        assert isinstance(value, torch.Tensor)
+        return value.to(self.device)
+
+
+class ToTorch(Mapper):
+    def _mapper(self, value):
+        if isinstance(value, np.ndarray):
+            return torch.from_numpy(value)
+        elif isinstance(value, list):
+            return torch.tensor(value)
+        elif isinstance(value, torch.Tensor):
+            return value
+        elif isinstance(value, (int, float)):
+            return torch.as_tensor(value)
+        else:
+            raise TypeError(f'Unsupported value type: {type(value)}')
+
+
+class Collate(Transform):
+    def __call__(self, batch):
+        assert batch and isinstance(batch, list), batch
+        processed_batch = {}
+
+        for key in batch[0].keys():
+            values = [sample[key] for sample in batch]
+            if isinstance(values[0], dict):
+                processed_batch[key] = self.__call__(values)
+            elif isinstance(values[0], np.ndarray):
+                processed_batch[key] = np.empty(shape=(0,), dtype=values[0].dtype)
+                values = [value for value in values if value.size > 0]
+                if len(values) > 0:
+                    processed_batch[key] = np.concatenate(values)
+            elif isinstance(values[0], torch.Tensor):
+                processed_batch[key] = torch.empty(size=(0,), dtype=values[0].dtype)
+                values = [value for value in values if value.numel() > 0]
+                if len(values) > 0:
+                    if values[0].ndim == 0:  # These are numbers
+                        processed_batch[key] = torch.stack(values)
+                    else:
+                        processed_batch[key] = torch.cat(values)
+            else:
+                processed_batch[key] = np.array(values)
+        return processed_batch
+
+
+
+
diff --git a/src/irec/dataloader/__init__.py b/src/irec/dataloader/__init__.py
deleted file mode 100644
index ef1642fc..00000000
--- a/src/irec/dataloader/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from .base import BaseDataloader
-from .batch_processors import BaseBatchProcessor, IdentityBatchProcessor
-
-__all__ = [
-    'BaseDataloader',
-    'BaseBatchProcessor',
-    'IdentityBatchProcessor',
-]
diff --git a/src/irec/dataloader/base.py b/src/irec/dataloader/base.py
deleted file mode 100644
index 06fdefcc..00000000
--- a/src/irec/dataloader/base.py
+++ /dev/null
@@ -1,43 +0,0 @@
-import copy
-
-from irec.utils import MetaParent
-from .batch_processors import BaseBatchProcessor
-
-import logging
-from torch.utils.data import DataLoader
-
-logger = logging.getLogger(__name__)
-
-
-class BaseDataloader(metaclass=MetaParent):
-    pass
-
-
-class TorchDataloader(BaseDataloader, config_name='torch'):
-    def __init__(self, dataloader):
-        self._dataloader = dataloader
-
-    def __iter__(self):
-        return iter(self._dataloader)
-
-    def __len__(self):
-        return len(self._dataloader)
-
-    @classmethod
-    def create_from_config(cls, config, **kwargs):
-        create_config = copy.deepcopy(config)
-        batch_processor = BaseBatchProcessor.create_from_config(
-            create_config.pop('batch_processor')
-            if 'batch_processor' in create_config
-            else {'type': 'identity'},
-        )
-        create_config.pop(
-            'type',
-        )  # For passing as **config in torch DataLoader
-        return cls(
-            dataloader=DataLoader(
-                kwargs['dataset'],
-                collate_fn=batch_processor,
-                **create_config,
-            ),
-        )
diff --git a/src/irec/dataloader/batch_processors.py b/src/irec/dataloader/batch_processors.py
deleted file mode 100644
index a8dbdde5..00000000
--- a/src/irec/dataloader/batch_processors.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import torch
-from irec.utils import MetaParent
-
-
-class BaseBatchProcessor(metaclass=MetaParent):
-    def __call__(self, batch):
-        raise NotImplementedError
-
-
-class IdentityBatchProcessor(BaseBatchProcessor, config_name='identity'):
-    def __call__(self, batch):
-        return torch.tensor(batch)
-
-
-class BasicBatchProcessor(BaseBatchProcessor, config_name='basic'):
-    def __call__(self, batch):
-        processed_batch = {}
-
-        for key in batch[0].keys():
-            if key.endswith('.ids'):
-                prefix = key.split('.')[0]
-                assert '{}.length'.format(prefix) in batch[0]
-
-                processed_batch[f'{prefix}.ids'] = []
-                processed_batch[f'{prefix}.length'] = []
-
-                for sample in batch:
-
processed_batch[f'{prefix}.ids'].extend( - sample[f'{prefix}.ids'], - ) - processed_batch[f'{prefix}.length'].append( - sample[f'{prefix}.length'], - ) - - for part, values in processed_batch.items(): - processed_batch[part] = torch.tensor(values, dtype=torch.long) - - return processed_batch diff --git a/src/irec/dataset/__init__.py b/src/irec/dataset/__init__.py deleted file mode 100644 index 0f179964..00000000 --- a/src/irec/dataset/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .base import BaseDataset - -__all__ = ['BaseDataset'] diff --git a/src/irec/dataset/base.py b/src/irec/dataset/base.py deleted file mode 100644 index 0d183fd4..00000000 --- a/src/irec/dataset/base.py +++ /dev/null @@ -1,921 +0,0 @@ -from collections import defaultdict - -from tqdm import tqdm - -from irec.dataset.samplers import TrainSampler, EvalSampler - -from irec.utils import MetaParent, DEVICE - -import pickle -import torch -import numpy as np -import scipy.sparse as sp -from scipy.sparse import csr_matrix - -import os -import logging - -logger = logging.getLogger(__name__) - - -class BaseDataset(metaclass=MetaParent): - def get_samplers(self): - raise NotImplementedError - - -class SequenceDataset(BaseDataset, config_name='sequence'): - def __init__( - self, - train_sampler, - validation_sampler, - test_sampler, - num_users, - num_items, - max_sequence_length, - ): - self._train_sampler = train_sampler - self._validation_sampler = validation_sampler - self._test_sampler = test_sampler - self._num_users = num_users - self._num_items = num_items - self._max_sequence_length = max_sequence_length - - @classmethod - def create_from_config(cls, config, **kwargs): - data_dir_path = os.path.join( - config['path_to_data_dir'], - config['name'], - ) - - train_dataset, train_max_user_id, train_max_item_id, train_seq_len = ( - cls._create_dataset( - dir_path=data_dir_path, - part='train', - max_sequence_length=config['max_sequence_length'], - use_cached=config.get('use_cached', False), - ) - ) - ( - validation_dataset, - valid_max_user_id, - valid_max_item_id, - valid_seq_len, - ) = cls._create_dataset( - dir_path=data_dir_path, - part='valid', - max_sequence_length=config['max_sequence_length'], - use_cached=config.get('use_cached', False), - ) - test_dataset, test_max_user_id, test_max_item_id, test_seq_len = ( - cls._create_dataset( - dir_path=data_dir_path, - part='test', - max_sequence_length=config['max_sequence_length'], - use_cached=config.get('use_cached', False), - ) - ) - - max_user_id = max( - [train_max_user_id, valid_max_user_id, test_max_user_id], - ) - max_item_id = max( - [train_max_item_id, valid_max_item_id, test_max_item_id], - ) - max_seq_len = max([train_seq_len, valid_seq_len, test_seq_len]) - - logger.info('Train dataset size: {}'.format(len(train_dataset))) - logger.info('Test dataset size: {}'.format(len(test_dataset))) - logger.info('Max user id: {}'.format(max_user_id)) - logger.info('Max item id: {}'.format(max_item_id)) - logger.info('Max sequence length: {}'.format(max_seq_len)) - - train_interactions = sum( - list(map(lambda x: len(x), train_dataset)), - ) # whole user history as a sample - valid_interactions = len( - validation_dataset, - ) # each new interaction as a sample - test_interactions = len( - test_dataset, - ) # each new interaction as a sample - logger.info( - '{} dataset sparsity: {}'.format( - config['name'], - (train_interactions + valid_interactions + test_interactions) - / max_user_id - / max_item_id, - ), - ) - - train_sampler = TrainSampler.create_from_config( - 
config['samplers'], - dataset=train_dataset, - num_users=max_user_id, - num_items=max_item_id, - ) - validation_sampler = EvalSampler.create_from_config( - config['samplers'], - dataset=validation_dataset, - num_users=max_user_id, - num_items=max_item_id, - ) - test_sampler = EvalSampler.create_from_config( - config['samplers'], - dataset=test_dataset, - num_users=max_user_id, - num_items=max_item_id, - ) - - return cls( - train_sampler=train_sampler, - validation_sampler=validation_sampler, - test_sampler=test_sampler, - num_users=max_user_id, - num_items=max_item_id, - max_sequence_length=max_seq_len, - ) - - @classmethod - def _create_dataset( - cls, - dir_path, - part, - max_sequence_length=None, - use_cached=False, - ): - max_user_id = 0 - max_item_id = 0 - max_sequence_len = 0 - - if use_cached and os.path.exists( - os.path.join(dir_path, '{}.pkl'.format(part)), - ): - logger.info( - f'Take cached dataset from {os.path.join(dir_path, "{}.pkl".format(part))}', - ) - - with open( - os.path.join(dir_path, '{}.pkl'.format(part)), - 'rb', - ) as dataset_file: - dataset, max_user_id, max_item_id, max_sequence_len = ( - pickle.load(dataset_file) - ) - else: - logger.info( - 'Cache is forecefully ignored.' - if not use_cached - else 'No cached dataset has been found.', - ) - logger.info( - f'Creating a dataset {os.path.join(dir_path, "{}.txt".format(part))}...', - ) - - dataset_path = os.path.join(dir_path, '{}.txt'.format(part)) - with open(dataset_path, 'r') as f: - data = f.readlines() - - sequence_info = cls._create_sequences(data, max_sequence_length) - ( - user_sequences, - item_sequences, - max_user_id, - max_item_id, - max_sequence_len, - ) = sequence_info - - dataset = [] - for user_id, item_ids in zip(user_sequences, item_sequences): - dataset.append( - { - 'user.ids': [user_id], - 'user.length': 1, - 'item.ids': item_ids, - 'item.length': len(item_ids), - }, - ) - - logger.info('{} dataset size: {}'.format(part, len(dataset))) - logger.info( - '{} dataset max sequence length: {}'.format( - part, - max_sequence_len, - ), - ) - - with open( - os.path.join(dir_path, '{}.pkl'.format(part)), - 'wb', - ) as dataset_file: - pickle.dump( - (dataset, max_user_id, max_item_id, max_sequence_len), - dataset_file, - ) - - return dataset, max_user_id, max_item_id, max_sequence_len - - @staticmethod - def _create_sequences(data, max_sample_len): - user_sequences = [] - item_sequences = [] - - max_user_id = 0 - max_item_id = 0 - max_sequence_length = 0 - - for sample in data: - sample = sample.strip('\n').split(' ') - item_ids = [int(item_id) for item_id in sample[1:]][ - -max_sample_len: - ] - user_id = int(sample[0]) - - max_user_id = max(max_user_id, user_id) - max_item_id = max(max_item_id, max(item_ids)) - max_sequence_length = max(max_sequence_length, len(item_ids)) - - user_sequences.append(user_id) - item_sequences.append(item_ids) - - return ( - user_sequences, - item_sequences, - max_user_id, - max_item_id, - max_sequence_length, - ) - - def get_samplers(self): - return ( - self._train_sampler, - self._validation_sampler, - self._test_sampler, - ) - - @property - def num_users(self): - return self._num_users - - @property - def num_items(self): - return self._num_items - - @property - def max_sequence_length(self): - return self._max_sequence_length - - @property - def meta(self): - return { - 'num_users': self.num_users, - 'num_items': self.num_items, - 'max_sequence_length': self.max_sequence_length, - } - - -class GraphDataset(BaseDataset, config_name='graph'): - def __init__( - 
self, - dataset, - graph_dir_path, - use_train_data_only=True, - use_user_graph=False, - use_item_graph=False, - neighborhood_size=None - ): - self._dataset = dataset - self._graph_dir_path = graph_dir_path - self._use_train_data_only = use_train_data_only - self._use_user_graph = use_user_graph - self._use_item_graph = use_item_graph - self._neighborhood_size = neighborhood_size - - self._num_users = dataset.num_users - self._num_items = dataset.num_items - - train_sampler, validation_sampler, test_sampler = ( - dataset.get_samplers() - ) - - ( - train_interactions, - train_user_interactions, - train_item_interactions, - ) = [], [], [] - - train_user_2_items = defaultdict(set) - train_item_2_users = defaultdict(set) - visited_user_item_pairs = set() - - for sample in train_sampler.dataset: - user_id = sample['user.ids'][0] - item_ids = sample['item.ids'] - - for item_id in item_ids: - if (user_id, item_id) not in visited_user_item_pairs: - train_interactions.append((user_id, item_id)) - train_user_interactions.append(user_id) - train_item_interactions.append(item_id) - - train_user_2_items[user_id].add(item_id) - train_item_2_users[item_id].add(user_id) - - visited_user_item_pairs.add((user_id, item_id)) - - # TODO create separated function - if not self._use_train_data_only: - for sample in validation_sampler.dataset: - user_id = sample['user.ids'][0] - item_ids = sample['item.ids'] - - for item_id in item_ids: - if (user_id, item_id) not in visited_user_item_pairs: - train_interactions.append((user_id, item_id)) - train_user_interactions.append(user_id) - train_item_interactions.append(item_id) - - train_user_2_items[user_id].add(item_id) - train_item_2_users[item_id].add(user_id) - - visited_user_item_pairs.add((user_id, item_id)) - - for sample in test_sampler.dataset: - user_id = sample['user.ids'][0] - item_ids = sample['item.ids'] - - for item_id in item_ids: - if (user_id, item_id) not in visited_user_item_pairs: - train_interactions.append((user_id, item_id)) - train_user_interactions.append(user_id) - train_item_interactions.append(item_id) - - train_user_2_items[user_id].add(item_id) - train_item_2_users[item_id].add(user_id) - - visited_user_item_pairs.add((user_id, item_id)) - - self._train_interactions = np.array(train_interactions) - self._train_user_interactions = np.array(train_user_interactions) - self._train_item_interactions = np.array(train_item_interactions) - - path_to_graph = os.path.join(graph_dir_path, 'general_graph.npz') - if os.path.exists(path_to_graph): - self._graph = sp.load_npz(path_to_graph) - else: - # place ones only when co-occurrence happens - user2item_connections = csr_matrix( - ( - np.ones(len(train_user_interactions)), - (train_user_interactions, train_item_interactions), - ), - shape=(self._num_users + 2, self._num_items + 2), - ) # (num_users + 2, num_items + 2), bipartite graph - self._graph = self.get_sparse_graph_layer( - user2item_connections, - self._num_users + 2, - self._num_items + 2, - biparite=True, - ) - sp.save_npz(path_to_graph, self._graph) - - self._graph = ( - self._convert_sp_mat_to_sp_tensor(self._graph) - .coalesce() - .to(DEVICE) - ) - - if self._use_user_graph: - path_to_user_graph = os.path.join(graph_dir_path, 'user_graph.npz') - if os.path.exists(path_to_user_graph): - self._user_graph = sp.load_npz(path_to_user_graph) - else: - user2user_interactions_fst = [] - user2user_interactions_snd = [] - visited_user_item_pairs = set() - visited_user_user_pairs = set() - - for user_id, item_id in tqdm( - zip( - 
self._train_user_interactions, - self._train_item_interactions, - ), - ): - if (user_id, item_id) in visited_user_item_pairs: - continue # process (user, item) pair only once - visited_user_item_pairs.add((user_id, item_id)) - - for connected_user_id in train_item_2_users[item_id]: - if ( - (user_id, connected_user_id) - in visited_user_user_pairs - or user_id == connected_user_id - ): - continue # add (user, user) to graph connections pair only once - visited_user_user_pairs.add( - (user_id, connected_user_id), - ) - - user2user_interactions_fst.append(user_id) - user2user_interactions_snd.append(connected_user_id) - - # (user, user) graph - user2user_connections = csr_matrix( - ( - np.ones(len(user2user_interactions_fst)), - ( - user2user_interactions_fst, - user2user_interactions_snd, - ), - ), - shape=(self._num_users + 2, self._num_users + 2), - ) - print(self._neighborhood_size) - if self._neighborhood_size is not None: - user2user_connections = self._filter_matrix_by_top_k(user2user_connections, self._neighborhood_size) - - self._user_graph = self.get_sparse_graph_layer( - user2user_connections, - self._num_users + 2, - self._num_users + 2, - biparite=False, - ) - sp.save_npz(path_to_user_graph, self._user_graph) - - self._user_graph = ( - self._convert_sp_mat_to_sp_tensor(self._user_graph) - .coalesce() - .to(DEVICE) - ) - else: - self._user_graph = None - - if self._use_item_graph: - path_to_item_graph = os.path.join(graph_dir_path, 'item_graph.npz') - if os.path.exists(path_to_item_graph): - self._item_graph = sp.load_npz(path_to_item_graph) - else: - item2item_interactions_fst = [] - item2item_interactions_snd = [] - visited_user_item_pairs = set() - visited_item_item_pairs = set() - - for user_id, item_id in tqdm( - zip( - self._train_user_interactions, - self._train_item_interactions, - ), - ): - if (user_id, item_id) in visited_user_item_pairs: - continue # process (user, item) pair only once - visited_user_item_pairs.add((user_id, item_id)) - - for connected_item_id in train_user_2_items[user_id]: - if ( - (item_id, connected_item_id) - in visited_item_item_pairs - or item_id == connected_item_id - ): - continue # add (item, item) to graph connections pair only once - visited_item_item_pairs.add( - (item_id, connected_item_id), - ) - - item2item_interactions_fst.append(item_id) - item2item_interactions_snd.append(connected_item_id) - - # (item, item) graph - item2item_connections = csr_matrix( - ( - np.ones(len(item2item_interactions_fst)), - ( - item2item_interactions_fst, - item2item_interactions_snd, - ), - ), - shape=(self._num_items + 2, self._num_items + 2), - ) - - if self._neighborhood_size is not None: - item2item_connections = self._filter_matrix_by_top_k(item2item_connections, self._neighborhood_size) - - self._item_graph = self.get_sparse_graph_layer( - item2item_connections, - self._num_items + 2, - self._num_items + 2, - biparite=False, - ) - sp.save_npz(path_to_item_graph, self._item_graph) - - self._item_graph = ( - self._convert_sp_mat_to_sp_tensor(self._item_graph) - .coalesce() - .to(DEVICE) - ) - else: - self._item_graph = None - - @classmethod - def create_from_config(cls, config): - dataset = BaseDataset.create_from_config(config['dataset']) - return cls( - dataset=dataset, - graph_dir_path=config['graph_dir_path'], - use_user_graph=config.get('use_user_graph', False), - use_item_graph=config.get('use_item_graph', False), - neighborhood_size=config.get('neighborhood_size', None), - ) - - @staticmethod - def get_sparse_graph_layer( - sparse_matrix, - 
fst_dim, - snd_dim, - biparite=False, - ): - if not biparite: - adj_mat = sparse_matrix.tocsr() - else: - R = sparse_matrix.tocsr() - - upper_right = R - lower_left = R.T - - upper_left = sp.csr_matrix((fst_dim, fst_dim)) - lower_right = sp.csr_matrix((snd_dim, snd_dim)) - - adj_mat = sp.bmat([ - [upper_left, upper_right], - [lower_left, lower_right] - ]) - assert adj_mat.shape == (fst_dim + snd_dim, fst_dim + snd_dim), ( - f"Got shape {adj_mat.shape}, expected {(fst_dim+snd_dim, fst_dim+snd_dim)}" - ) - - rowsum = np.array(adj_mat.sum(1)) - d_inv = np.power(rowsum, -0.5).flatten() - d_inv[np.isinf(d_inv)] = 0. - d_mat_inv = sp.diags(d_inv) - - norm_adj = d_mat_inv.dot(adj_mat).dot(d_mat_inv) - return norm_adj.tocsr() - - @staticmethod - def _convert_sp_mat_to_sp_tensor(X): - coo = X.tocoo().astype(np.float32) - row = torch.Tensor(coo.row).long() - col = torch.Tensor(coo.col).long() - index = torch.stack([row, col]) - data = torch.FloatTensor(coo.data) - return torch.sparse.FloatTensor(index, data, torch.Size(coo.shape)) - - @staticmethod - def _filter_matrix_by_top_k(matrix, k): - mat = matrix.tolil() - - for i in range(mat.shape[0]): - if len(mat.rows[i]) <= k: - continue - data = np.array(mat.data[i]) - - top_k_indices = np.argpartition(data, -k)[-k:] - mat.data[i] = [mat.data[i][j] for j in top_k_indices] - mat.rows[i] = [mat.rows[i][j] for j in top_k_indices] - - return mat.tocsr() - - - @property - def num_users(self): - return self._dataset.num_users - - @property - def num_items(self): - return self._dataset.num_items - - def get_samplers(self): - return self._dataset.get_samplers() - - @property - def meta(self): - meta = { - 'user_graph': self._user_graph, - 'item_graph': self._item_graph, - 'graph': self._graph, - **self._dataset.meta, - } - return meta - - -class DuorecDataset(BaseDataset, config_name='duorec'): - def __init__(self, dataset): - self._dataset = dataset - self._num_users = dataset.num_users - self._num_items = dataset.num_items - - train_sampler, _, _ = self._dataset.get_samplers() - - target_2_sequences = defaultdict(list) - for sample in train_sampler.dataset: - item_ids = sample['item.ids'] - - target_item = item_ids[-1] - semantic_similar_item_ids = item_ids[:-1] - - target_2_sequences[target_item].append(semantic_similar_item_ids) - - train_sampler._target_2_sequences = target_2_sequences - - @classmethod - def create_from_config(cls, config): - dataset = BaseDataset.create_from_config(config['dataset']) - return cls(dataset) - - @property - def num_users(self): - return self._dataset.num_users - - @property - def num_items(self): - return self._dataset.num_items - - def get_samplers(self): - return self._dataset.get_samplers() - - @property - def meta(self): - return self._dataset.meta - - -class ScientificDataset(BaseDataset, config_name='scientific'): - def __init__( - self, - train_sampler, - validation_sampler, - test_sampler, - num_users, - num_items, - max_sequence_length, - ): - self._train_sampler = train_sampler - self._validation_sampler = validation_sampler - self._test_sampler = test_sampler - self._num_users = num_users - self._num_items = num_items - self._max_sequence_length = max_sequence_length - - @classmethod - def create_from_config(cls, config, **kwargs): - data_dir_path = os.path.join( - config['path_to_data_dir'], - config['name'], - ) - max_sequence_length = config['max_sequence_length'] - max_user_id, max_item_id = 0, 0 - train_dataset, validation_dataset, test_dataset = [], [], [] - - dataset_path = os.path.join(data_dir_path, 
'{}.txt'.format('all_data')) - with open(dataset_path, 'r') as f: - data = f.readlines() - - for sample in data: - sample = sample.strip('\n').split(' ') - user_id = int(sample[0]) - item_ids = [int(item_id) for item_id in sample[1:]] - - max_user_id = max(max_user_id, user_id) - max_item_id = max(max_item_id, max(item_ids)) - - assert len(item_ids) >= 5 - - train_dataset.append( - { - 'user.ids': [user_id], - 'user.length': 1, - 'item.ids': item_ids[:-2][-max_sequence_length:], - 'item.length': len(item_ids[:-2][-max_sequence_length:]), - }, - ) - assert len(item_ids[:-2][-max_sequence_length:]) == len( - set(item_ids[:-2][-max_sequence_length:]), - ) - validation_dataset.append( - { - 'user.ids': [user_id], - 'user.length': 1, - 'item.ids': item_ids[:-1][-max_sequence_length:], - 'item.length': len(item_ids[:-1][-max_sequence_length:]), - }, - ) - assert len(item_ids[:-1][-max_sequence_length:]) == len( - set(item_ids[:-1][-max_sequence_length:]), - ) - test_dataset.append( - { - 'user.ids': [user_id], - 'user.length': 1, - 'item.ids': item_ids[-max_sequence_length:], - 'item.length': len(item_ids[-max_sequence_length:]), - }, - ) - assert len(item_ids[-max_sequence_length:]) == len( - set(item_ids[-max_sequence_length:]), - ) - - logger.info('Train dataset size: {}'.format(len(train_dataset))) - logger.info('Test dataset size: {}'.format(len(test_dataset))) - logger.info('Max user id: {}'.format(max_user_id)) - logger.info('Max item id: {}'.format(max_item_id)) - logger.info('Max sequence length: {}'.format(max_sequence_length)) - logger.info( - '{} dataset sparsity: {}'.format( - config['name'], - (len(train_dataset) + len(test_dataset)) - / max_user_id - / max_item_id, - ), - ) - - train_sampler = TrainSampler.create_from_config( - config['samplers'], - dataset=train_dataset, - num_users=max_user_id, - num_items=max_item_id, - ) - validation_sampler = EvalSampler.create_from_config( - config['samplers'], - dataset=validation_dataset, - num_users=max_user_id, - num_items=max_item_id, - ) - test_sampler = EvalSampler.create_from_config( - config['samplers'], - dataset=test_dataset, - num_users=max_user_id, - num_items=max_item_id, - ) - - return cls( - train_sampler=train_sampler, - validation_sampler=validation_sampler, - test_sampler=test_sampler, - num_users=max_user_id, - num_items=max_item_id, - max_sequence_length=max_sequence_length, - ) - - def get_samplers(self): - return ( - self._train_sampler, - self._validation_sampler, - self._test_sampler, - ) - - @property - def num_users(self): - return self._num_users - - @property - def num_items(self): - return self._num_items - - @property - def max_sequence_length(self): - return self._max_sequence_length - - @property - def meta(self): - return { - 'num_users': self.num_users, - 'num_items': self.num_items, - 'max_sequence_length': self.max_sequence_length, - } - - -class MCLSRDataset(BaseDataset, config_name='mclsr'): - def __init__(self, train_sampler, validation_sampler, test_sampler, num_users, num_items, max_sequence_length): - self._train_sampler = train_sampler - self._validation_sampler = validation_sampler - self._test_sampler = test_sampler - self._num_users = num_users - self._num_items = num_items - self._max_sequence_length = max_sequence_length - - @staticmethod - def _create_sequences_from_file(filepath, max_len=None): - sequences = {} - max_user, max_item = 0, 0 - - with open(filepath, 'r') as f: - for line in f: - parts = line.strip().split(' ') - user_id = int(parts[0]) - item_ids = [int(i) for i in parts[1:]] - 
if max_len: - item_ids = item_ids[-max_len:] - sequences[user_id] = item_ids - max_user = max(max_user, user_id) - if item_ids: - max_item = max(max_item, max(item_ids)) - return sequences, max_user, max_item - - @classmethod - def _create_evaluation_sets(cls, data_dir, max_seq_len): - valid_hist, u2, i2 = cls._create_sequences_from_file(os.path.join(data_dir, 'valid_history.txt'), max_seq_len) - valid_trg, u3, i3 = cls._create_sequences_from_file(os.path.join(data_dir, 'valid_target.txt')) - - validation_dataset = [{'user.ids': [uid], 'history': valid_hist[uid], 'target': valid_trg[uid]} for uid in valid_hist if uid in valid_trg] - - test_hist, u4, i4 = cls._create_sequences_from_file(os.path.join(data_dir, 'test_history.txt'), max_seq_len) - test_trg, u5, i5 = cls._create_sequences_from_file(os.path.join(data_dir, 'test_target.txt')) - - test_dataset = [{'user.ids': [uid], 'history': test_hist[uid], 'target': test_trg[uid]} for uid in test_hist if uid in test_trg] - - return validation_dataset, test_dataset, max(u2, u3, u4, u5), max(i2, i3, i4, i5) - - @classmethod - def create_from_config(cls, config, **kwargs): - data_dir = os.path.join(config['path_to_data_dir'], config['name']) - max_seq_len = config.get('max_sequence_length') - - train_sequences, u1, i1 = cls._create_sequences_from_file(os.path.join(data_dir, 'train_mclsr.txt'), max_seq_len) - train_dataset = [{'user.ids': [uid], 'user.length': 1, 'item.ids': seq, 'item.length': len(seq)} for uid, seq in train_sequences.items()] - - user_to_all_seen_items = defaultdict(set) - for sample in train_dataset: user_to_all_seen_items[sample['user.ids'][0]].update(sample['item.ids']) - kwargs['user_to_all_seen_items'] = user_to_all_seen_items - - validation_dataset, test_dataset, u_eval, i_eval = cls._create_evaluation_sets(data_dir, max_seq_len) - num_users = max(u1, u_eval) - num_items = max(i1, i_eval) - - train_sampler = TrainSampler.create_from_config(config['samplers'], dataset=train_dataset, num_users=num_users, num_items=num_items, **kwargs) - validation_sampler = EvalSampler.create_from_config(config['samplers'], dataset=validation_dataset, num_users=num_users, num_items=num_items, **kwargs) - test_sampler = EvalSampler.create_from_config(config['samplers'], dataset=test_dataset, num_users=num_users, num_items=num_items, **kwargs) - - return cls(train_sampler, validation_sampler, test_sampler, num_users, num_items, max_seq_len) - - def get_samplers(self): - return (self._train_sampler, self._validation_sampler, self._test_sampler) - - @property - def num_users(self): - return self._num_users - - @property - def num_items(self): - return self._num_items - - @property - def meta(self): - return {'num_users': self.num_users, 'num_items': self.num_items, 'max_sequence_length': self._max_sequence_length} - -class SASRecDataset(BaseDataset, config_name='sasrec_comparison'): - def __init__(self, train_sampler, validation_sampler, test_sampler, num_users, num_items, max_sequence_length): - self._train_sampler = train_sampler - self._validation_sampler = validation_sampler - self._test_sampler = test_sampler - self._num_users = num_users - self._num_items = num_items - self._max_sequence_length = max_sequence_length - - @classmethod - def create_from_config(cls, config, **kwargs): - data_dir = os.path.join(config['path_to_data_dir'], config['name']) - max_seq_len = config.get('max_sequence_length') - - train_dataset, u1, i1, _ = SequenceDataset._create_dataset( - dir_path=data_dir, - part='train_sasrec', - max_sequence_length=max_seq_len - 
) - - validation_dataset, test_dataset, u_eval, i_eval = MCLSRDataset._create_evaluation_sets(data_dir, max_seq_len) - - num_users = max(u1, u_eval) - num_items = max(i1, i_eval) - train_sampler = TrainSampler.create_from_config( - config['train_sampler'], - dataset=train_dataset, num_users=num_users, num_items=num_items - ) - - validation_sampler = EvalSampler.create_from_config( - config['eval_sampler'], - dataset=validation_dataset, num_users=num_users, num_items=num_items - ) - test_sampler = EvalSampler.create_from_config( - config['eval_sampler'], - dataset=test_dataset, num_users=num_users, num_items=num_items - ) - - return cls(train_sampler, validation_sampler, test_sampler, num_users, num_items, max_seq_len) - - def get_samplers(self): - return (self._train_sampler, self._validation_sampler, self._test_sampler) - - @property - def num_users(self): return self._num_users - @property - def num_items(self): return self._num_items - @property - def meta(self): - return {'num_users': self.num_users, 'num_items': self.num_items, 'max_sequence_length': self._max_sequence_length} \ No newline at end of file diff --git a/src/irec/dataset/negative_samplers/__init__.py b/src/irec/dataset/negative_samplers/__init__.py deleted file mode 100644 index 04b82f92..00000000 --- a/src/irec/dataset/negative_samplers/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .base import BaseNegativeSampler -from .popular import PopularNegativeSampler -from .random import RandomNegativeSampler - -__all__ = [ - 'BaseNegativeSampler', - 'PopularNegativeSampler', - 'RandomNegativeSampler', -] diff --git a/src/irec/dataset/negative_samplers/base.py b/src/irec/dataset/negative_samplers/base.py deleted file mode 100644 index b4d12248..00000000 --- a/src/irec/dataset/negative_samplers/base.py +++ /dev/null @@ -1,19 +0,0 @@ -from collections import defaultdict - -from irec.utils import MetaParent - - -class BaseNegativeSampler(metaclass=MetaParent): - def __init__(self, dataset, num_users, num_items): - self._dataset = dataset - self._num_users = num_users - self._num_items = num_items - - self._seen_items = defaultdict(set) - for sample in self._dataset: - user_id = sample['user.ids'][0] - items = list(sample['item.ids']) - self._seen_items[user_id].update(items) - - def generate_negative_samples(self, sample, num_negatives): - raise NotImplementedError diff --git a/src/irec/dataset/negative_samplers/popular.py b/src/irec/dataset/negative_samplers/popular.py deleted file mode 100644 index ca91bf6c..00000000 --- a/src/irec/dataset/negative_samplers/popular.py +++ /dev/null @@ -1,44 +0,0 @@ -from irec.dataset.negative_samplers.base import BaseNegativeSampler - -from collections import Counter - - -class PopularNegativeSampler(BaseNegativeSampler, config_name='popular'): - def __init__(self, dataset, num_users, num_items): - super().__init__( - dataset=dataset, - num_users=num_users, - num_items=num_items, - ) - - self._popular_items = self._items_by_popularity() - - @classmethod - def create_from_config(cls, _, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - ) - - def _items_by_popularity(self): - popularity = Counter() - - for sample in self._dataset: - for item_id in sample['item.ids']: - popularity[item_id] += 1 - - popular_items = sorted(popularity, key=popularity.get, reverse=True) - return popular_items - - def generate_negative_samples(self, sample, num_negatives): - user_id = sample['user.ids'][0] - popularity_idx = 0 - negatives = [] - 
while len(negatives) < num_negatives: - negative_idx = self._popular_items[popularity_idx] - if negative_idx not in self._seen_items[user_id]: - negatives.append(negative_idx) - popularity_idx += 1 - - return negatives diff --git a/src/irec/dataset/negative_samplers/random.py b/src/irec/dataset/negative_samplers/random.py deleted file mode 100644 index 79e245a6..00000000 --- a/src/irec/dataset/negative_samplers/random.py +++ /dev/null @@ -1,28 +0,0 @@ -from irec.dataset.negative_samplers.base import BaseNegativeSampler - -import numpy as np - - -class RandomNegativeSampler(BaseNegativeSampler, config_name='random'): - @classmethod - def create_from_config(cls, _, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - ) - - def generate_negative_samples(self, sample, num_negatives): - user_id = sample['user.ids'][0] - all_items = list(range(1, self._num_items + 1)) - np.random.shuffle(all_items) - - negatives = [] - running_idx = 0 - while len(negatives) < num_negatives and running_idx < len(all_items): - negative_idx = all_items[running_idx] - if negative_idx not in self._seen_items[user_id]: - negatives.append(negative_idx) - running_idx += 1 - - return negatives diff --git a/src/irec/dataset/samplers/__init__.py b/src/irec/dataset/samplers/__init__.py deleted file mode 100644 index 88198954..00000000 --- a/src/irec/dataset/samplers/__init__.py +++ /dev/null @@ -1,40 +0,0 @@ -from .base import TrainSampler, EvalSampler -from .cl4srec import Cl4SRecTrainSampler, Cl4SRecEvalSampler -from .duorec import DuorecTrainSampler, DuoRecEvalSampler -from .next_item_prediction import ( - NextItemPredictionTrainSampler, - NextItemPredictionEvalSampler, -) -from .last_item_prediction import ( - LastItemPredictionTrainSampler, - LastItemPredictionEvalSampler, -) -from .masked_item_prediction import ( - MaskedItemPredictionTrainSampler, - MaskedItemPredictionEvalSampler, -) -from .mclsr import MCLSRTrainSampler, MCLSRPredictionEvalSampler -from .pop import PopTrainSampler, PopEvalSampler -from .s3rec import S3RecPretrainTrainSampler, S3RecPretrainEvalSampler - - -__all__ = [ - 'TrainSampler', - 'EvalSampler', - 'Cl4SRecTrainSampler', - 'Cl4SRecEvalSampler', - 'DuorecTrainSampler', - 'DuoRecEvalSampler', - 'NextItemPredictionTrainSampler', - 'NextItemPredictionEvalSampler', - 'LastItemPredictionTrainSampler', - 'LastItemPredictionEvalSampler', - 'MaskedItemPredictionTrainSampler', - 'MaskedItemPredictionEvalSampler', - 'MCLSRTrainSampler', - 'MCLSRPredictionEvalSampler', - 'PopTrainSampler', - 'PopEvalSampler', - 'S3RecPretrainTrainSampler', - 'S3RecPretrainEvalSampler', -] diff --git a/src/irec/dataset/samplers/base.py b/src/irec/dataset/samplers/base.py deleted file mode 100644 index 158bede4..00000000 --- a/src/irec/dataset/samplers/base.py +++ /dev/null @@ -1,48 +0,0 @@ -from irec.utils import MetaParent - -import copy - - -class TrainSampler(metaclass=MetaParent): - def __init__(self): - self._dataset = None - - @property - def dataset(self): - return self._dataset - - def __len__(self): - return len(self._dataset) - - def __getitem__(self, index): - raise NotImplementedError - - -class EvalSampler(metaclass=MetaParent): - def __init__(self, dataset, num_users, num_items): - super().__init__() - self._dataset = dataset - self._num_users = num_users - self._num_items = num_items - - @property - def dataset(self): - return self._dataset - - def __len__(self): - return len(self._dataset) - - def __getitem__(self, index): - sample = 
copy.deepcopy(self._dataset[index]) - - item_sequence = sample['item.ids'][:-1] - next_item = sample['item.ids'][-1] - - return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'labels.ids': [next_item], - 'labels.length': 1, - } diff --git a/src/irec/dataset/samplers/cl4srec.py b/src/irec/dataset/samplers/cl4srec.py deleted file mode 100644 index f64385b5..00000000 --- a/src/irec/dataset/samplers/cl4srec.py +++ /dev/null @@ -1,150 +0,0 @@ -import numpy as np - -from irec.dataset.samplers.base import TrainSampler, EvalSampler - -import copy - - -class Cl4SRecTrainSampler(TrainSampler, config_name='cl4srec'): - def __init__( - self, - dataset, - num_users, - num_items, - item_crop_portion, - item_mask_portion, - item_reorder_portion, - ): - super().__init__() - self._dataset = dataset - self._num_users = num_users - self._num_items = num_items - self._mask_item_idx = self._num_items + 1 - self._item_crop_portion = item_crop_portion - self._item_mask_portion = item_mask_portion - self._item_reorder_portion = item_reorder_portion - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - item_crop_portion=config['item_crop_portion'], - item_mask_portion=config['item_mask_portion'], - item_reorder_portion=config['item_reorder_portion'], - ) - - def _apply_crop_augmentation(self, item_sequence): - num_elements_to_crop = max( - 1, - int(self._item_crop_portion * len(item_sequence)), - ) - crop_start_index = np.random.randint( - low=0, - high=len(item_sequence) - num_elements_to_crop + 1, - ) - assert ( - 0 <= crop_start_index <= len(item_sequence) - num_elements_to_crop - ) - item_sequence = item_sequence[ - crop_start_index : crop_start_index + num_elements_to_crop - ] - return item_sequence - - def _apply_mask_augmentation(self, item_sequence): - for idx in range(len(item_sequence)): - p = np.random.uniform(low=0.0, high=1.0) - if p < self._item_mask_portion: - item_sequence[idx] = self._mask_item_idx - - if p < self._item_mask_portion: - p /= self._item_mask_portion - - if p < 0.8: - item_sequence[idx] = self._mask_item_idx - elif p < 0.9: - item_sequence[idx] = np.random.randint( - 1, - self._num_items + 1, - ) - else: - pass # item_sequence[idx] = item_sequence[idx] - else: - pass # item_sequence[idx] = item_sequence[idx] - - return item_sequence - - def _apply_reorder_augmentation(self, item_sequence): - num_elements_to_reorder = int( - self._item_reorder_portion * len(item_sequence), - ) - reorder_start_index = np.random.randint( - low=0, - high=len(item_sequence) - num_elements_to_reorder + 1, - ) - assert ( - 0 - <= reorder_start_index - <= len(item_sequence) - num_elements_to_reorder - ) - reordered_subsequence = item_sequence[ - reorder_start_index : reorder_start_index + num_elements_to_reorder - ] - np.random.shuffle(reordered_subsequence) # This works in-place - item_sequence = ( - item_sequence[:reorder_start_index] - + reordered_subsequence - + item_sequence[reorder_start_index + num_elements_to_reorder :] - ) - return item_sequence - - def __getitem__(self, index): - sample = copy.deepcopy(self._dataset[index]) - - item_sequence = sample['item.ids'][:-1] - next_item = sample['item.ids'][-1] - - sample_items = set(sample['item.ids']) - negative = np.random.randint(low=1, high=self._num_items + 1) - while negative in sample_items: - negative = np.random.randint(low=1, 
high=self._num_items + 1) - - augmentation_list = [ - self._apply_crop_augmentation, - self._apply_mask_augmentation, - self._apply_reorder_augmentation, - ] - - fst_augmentation = np.random.choice(augmentation_list) - snd_augmentation = np.random.choice(augmentation_list) - - fst_augmented_sequence = fst_augmentation(item_sequence) - snd_augmented_sequence = snd_augmentation(item_sequence) - - return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'fst_augmented_item.ids': fst_augmented_sequence, - 'fst_augmented_item.length': len(fst_augmented_sequence), - 'snd_augmented_item.ids': snd_augmented_sequence, - 'snd_augmented_item.length': len(snd_augmented_sequence), - 'labels.ids': [next_item], - 'labels.length': 1, - 'positive.ids': [next_item], - 'positive.length': 1, - 'negative.ids': [negative], - 'negative.length': 1, - } - - -class Cl4SRecEvalSampler(EvalSampler, config_name='cl4srec'): - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - ) diff --git a/src/irec/dataset/samplers/duorec.py b/src/irec/dataset/samplers/duorec.py deleted file mode 100644 index e47fe2ed..00000000 --- a/src/irec/dataset/samplers/duorec.py +++ /dev/null @@ -1,55 +0,0 @@ -import random - -from irec.dataset.samplers.base import TrainSampler, EvalSampler - -import copy - - -class DuorecTrainSampler(TrainSampler, config_name='duorec'): - def __init__(self, dataset, num_users, num_items): - super().__init__() - self._dataset = dataset - self._num_users = num_users - self._num_items = num_items - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - ) - - def __getitem__(self, index): - sample = copy.deepcopy(self._dataset[index]) - - item_sequence = sample['item.ids'] - - target_item = item_sequence[-1] - item_sequence = item_sequence[:-1] - - # There is a probability of sampling the same sequence - semantic_similar_sequence = random.choice( - self._target_2_sequences[target_item], - ) - - return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'labels.ids': [target_item], - 'labels.length': 1, - 'semantic_similar_item.ids': semantic_similar_sequence, - 'semantic_similar_item.length': len(semantic_similar_sequence), - } - - -class DuoRecEvalSampler(EvalSampler, config_name='duorec'): - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - ) diff --git a/src/irec/dataset/samplers/last_item_prediction.py b/src/irec/dataset/samplers/last_item_prediction.py deleted file mode 100644 index 20f2f3fd..00000000 --- a/src/irec/dataset/samplers/last_item_prediction.py +++ /dev/null @@ -1,50 +0,0 @@ -from irec.dataset.samplers.base import TrainSampler, EvalSampler - -import copy - - -class LastItemPredictionTrainSampler( - TrainSampler, - config_name='last_item_prediction', -): - def __init__(self, dataset, num_users, num_items): - super().__init__() - self._dataset = dataset - self._num_users = num_users - self._num_items = num_items - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - 
num_items=kwargs['num_items'], - ) - - def __getitem__(self, index): - sample = copy.deepcopy(self._dataset[index]) - - item_sequence = sample['item.ids'][:-1] - last_item = sample['item.ids'][-1] - - return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'labels.ids': [last_item], - 'labels.length': 1, - } - - -class LastItemPredictionEvalSampler( - EvalSampler, - config_name='last_item_prediction', -): - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - ) diff --git a/src/irec/dataset/samplers/masked_item_prediction.py b/src/irec/dataset/samplers/masked_item_prediction.py deleted file mode 100644 index a78657e8..00000000 --- a/src/irec/dataset/samplers/masked_item_prediction.py +++ /dev/null @@ -1,99 +0,0 @@ -from irec.dataset.samplers.base import TrainSampler, EvalSampler - -import copy -import numpy as np - - -class MaskedItemPredictionTrainSampler( - TrainSampler, - config_name='masked_item_prediction', -): - def __init__(self, dataset, num_users, num_items, mask_prob=0.0): - super().__init__() - self._dataset = dataset - self._num_users = num_users - self._num_items = num_items - self._mask_item_idx = self._num_items + 1 - self._mask_prob = mask_prob - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - mask_prob=config.get('mask_prob', 0.0), - ) - - def __getitem__(self, index): - sample = copy.deepcopy(self._dataset[index]) - - item_sequence = sample['item.ids'] - - masked_sequence = [] - labels = [] - - for item in item_sequence: - prob = np.random.uniform(low=0.0, high=1.0) - - if prob < self._mask_prob: - prob /= self._mask_prob - - if prob < 0.8: - masked_sequence.append(self._mask_item_idx) - elif prob < 0.9: - masked_sequence.append( - np.random.randint(1, self._num_items + 1), - ) - else: - masked_sequence.append(item) - - labels.append(item) - else: - masked_sequence.append(item) - labels.append(0) - - # Mask last item - masked_sequence[-1] = self._mask_item_idx - labels[-1] = item_sequence[-1] - - return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': masked_sequence, - 'item.length': len(masked_sequence), - 'labels.ids': labels, - 'labels.length': len(labels), - } - - -class MaskedItemPredictionEvalSampler( - EvalSampler, - config_name='masked_item_prediction', -): - def __init__(self, dataset, num_users, num_items): - super().__init__(dataset, num_users, num_items) - self._mask_item_idx = self._num_items + 1 - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - ) - - def __getitem__(self, index): - sample = copy.deepcopy(self._dataset[index]) - item_sequence = sample['item.ids'] - labels = [item_sequence[-1]] - sequence = item_sequence[:-1] + [self._mask_item_idx] - - return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': sequence, - 'item.length': len(sequence), - 'labels.ids': labels, - 'labels.length': len(labels), - } diff --git a/src/irec/dataset/samplers/mclsr.py b/src/irec/dataset/samplers/mclsr.py deleted file mode 100644 index b957cde7..00000000 --- a/src/irec/dataset/samplers/mclsr.py +++ /dev/null @@ -1,80 +0,0 @@ -from 
irec.dataset.samplers.base import TrainSampler, EvalSampler - -from collections import defaultdict -import random - - -class MCLSRTrainSampler(TrainSampler, config_name='mclsr'): - def __init__(self, dataset, num_users, num_items, user_to_all_seen_items, num_negatives, **kwargs): - super().__init__() - self._dataset = dataset - self._num_users = num_users - self._num_items = num_items - self._num_negatives = num_negatives - self._all_items_set = set(range(1, num_items + 1)) - self._user_to_all_seen_items = user_to_all_seen_items - - @classmethod - def create_from_config(cls, config, **kwargs): - num_negatives = config['num_negatives_train'] - print(num_negatives) - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - num_negatives=num_negatives, - user_to_all_seen_items=kwargs['user_to_all_seen_items'], - ) - - - def __getitem__(self, index): - sample = self._dataset[index] - - user_id = sample['user.ids'][0] - item_sequence = sample['item.ids'][:-1] - positive_item = sample['item.ids'][-1] - - user_seen = self._user_to_all_seen_items[user_id] - - unseen_items = list(self._all_items_set - user_seen) - - negatives = random.sample(unseen_items, self._num_negatives) - - - return { - 'user.ids': [user_id], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'labels.ids': [positive_item], - 'labels.length': 1, - 'negatives.ids': negatives, - 'negatives.length': len(negatives), - } - - -class MCLSRPredictionEvalSampler(EvalSampler, config_name='mclsr'): - def __init__(self, dataset, num_users, num_items): - super().__init__(dataset, num_users, num_items) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - ) - - def __getitem__(self, index): - sample = self._dataset[index] - history_sequence = sample['history'] - target_items = sample['target'] - - return { - 'user.ids': sample['user.ids'], - 'user.length': 1, - 'item.ids': history_sequence, - 'item.length': len(history_sequence), - 'labels.ids': target_items, - 'labels.length': len(target_items), - } diff --git a/src/irec/dataset/samplers/next_item_prediction.py b/src/irec/dataset/samplers/next_item_prediction.py deleted file mode 100644 index e263062f..00000000 --- a/src/irec/dataset/samplers/next_item_prediction.py +++ /dev/null @@ -1,86 +0,0 @@ -from irec.dataset.samplers.base import TrainSampler, EvalSampler -from irec.dataset.negative_samplers.base import BaseNegativeSampler - -import copy - - -class NextItemPredictionTrainSampler( - TrainSampler, - config_name='next_item_prediction', -): - def __init__( - self, - dataset, - num_users, - num_items, - negative_sampler, - num_negatives=0, - ): - super().__init__() - self._dataset = dataset - self._num_users = num_users - self._num_items = num_items - self._negative_sampler = negative_sampler - self._num_negatives = num_negatives - - @classmethod - def create_from_config(cls, config, **kwargs): - negative_sampler = BaseNegativeSampler.create_from_config( - {'type': config['negative_sampler_type']}, - **kwargs, - ) - - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - negative_sampler=negative_sampler, - num_negatives=config.get('num_negatives_train', 0), - ) - - def __getitem__(self, index): - sample = copy.deepcopy(self._dataset[index]) - - item_sequence = sample['item.ids'][:-1] - next_item_sequence = 
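The negative sampling in MCLSRTrainSampler above boils down to a uniform draw over items the user has never interacted with. A standalone sketch (sample_unseen_negatives is an illustrative name; ids are assumed to run 1..num_items as elsewhere in these samplers); since the set difference costs O(num_items) per __getitem__ call, caching the unseen list per user is an easy optimisation for large catalogues:

import random

def sample_unseen_negatives(num_items, seen_items, num_negatives):
    # seen_items: set of item ids the user has interacted with
    unseen = list(set(range(1, num_items + 1)) - seen_items)
    return random.sample(unseen, num_negatives)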
sample['item.ids'][1:] - - if self._num_negatives == 0: - return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'positive.ids': next_item_sequence, - 'positive.length': len(next_item_sequence), - } - else: - negative_sequence = ( - self._negative_sampler.generate_negative_samples( - sample, - self._num_negatives, - ) - ) - - return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': item_sequence, - 'item.length': len(item_sequence), - 'positive.ids': next_item_sequence, - 'positive.length': len(next_item_sequence), - 'negative.ids': negative_sequence, - 'negative.length': len(negative_sequence), - } - - -class NextItemPredictionEvalSampler( - EvalSampler, - config_name='next_item_prediction', -): - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - ) diff --git a/src/irec/dataset/samplers/pop.py b/src/irec/dataset/samplers/pop.py deleted file mode 100644 index 7d4c01b7..00000000 --- a/src/irec/dataset/samplers/pop.py +++ /dev/null @@ -1,65 +0,0 @@ -from irec.dataset.samplers.base import TrainSampler, EvalSampler - -import copy - -from collections import Counter - - -CANDIDATE_COUNTS = None - - -class PopTrainSampler(TrainSampler, config_name='pop'): - def __init__(self, dataset, num_items): - super().__init__() - - global CANDIDATE_COUNTS - - if CANDIDATE_COUNTS is None: - item_2_count = Counter() - - for sample in dataset: - items = sample['item.ids'] - for item in items: - item_2_count[item] += 1 - - CANDIDATE_COUNTS = ( - [0] - + [ - self._item_2_count[item_id] - for item_id in range(1, self._num_items + 1) - ] - + [0] - ) # Mask + items + padding - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls(dataset=kwargs['dataset'], num_items=kwargs['num_items']) - - -class PopEvalSampler(EvalSampler, config_name='pop'): - def __init__(self, dataset, num_users, num_items): - super().__init__(dataset, num_users, num_items) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - ) - - def __getitem__(self, index): - sample = copy.deepcopy(self._dataset[index]) - labels = [sample['item.ids'][-1]] - - global CANDIDATE_COUNTS - assert CANDIDATE_COUNTS is not None - - return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'labels.ids': labels, - 'labels.length': len(labels), - 'candidates_counts.ids': CANDIDATE_COUNTS, - 'candidates_counts.length': len(CANDIDATE_COUNTS), - } diff --git a/src/irec/dataset/samplers/s3rec.py b/src/irec/dataset/samplers/s3rec.py deleted file mode 100644 index 0cff74a4..00000000 --- a/src/irec/dataset/samplers/s3rec.py +++ /dev/null @@ -1,131 +0,0 @@ -from irec.dataset.samplers.base import TrainSampler, EvalSampler -from irec.dataset.negative_samplers.base import BaseNegativeSampler - -import copy -import numpy as np - - -class S3RecPretrainTrainSampler(TrainSampler, config_name='s3rec_pretrain'): - def __init__( - self, - dataset, - num_users, - num_items, - negative_sampler, - mask_prob=0.0, - ): - super().__init__() - self._dataset = dataset - self._num_users = num_users - self._num_items = num_items - self._mask_item_idx = self._num_items + 1 - self._mask_prob = mask_prob - self._negative_sampler = negative_sampler - - self._long_sequence = 
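PopTrainSampler above fills the module-level CANDIDATE_COUNTS table but refers to self._item_2_count and self._num_items, neither of which it ever assigns, so as written the first construction would raise AttributeError. A corrected standalone sketch of the same table (build_candidate_counts is an illustrative name; slot 0 and slot num_items + 1 stay zero for the padding and mask ids):

from collections import Counter

def build_candidate_counts(dataset, num_items):
    # Global popularity of every item id, aligned with the model's
    # id space: [0] + counts for ids 1..num_items + [0].
    item_counts = Counter()
    for sample in dataset:
        item_counts.update(sample['item.ids'])
    return (
        [0]
        + [item_counts[item_id] for item_id in range(1, num_items + 1)]
        + [0]
    )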
[] - for sample in self._dataset: - self._long_sequence.extend(copy.deepcopy(sample['item.ids'])) - - @classmethod - def create_from_config(cls, config, **kwargs): - negative_sampler = BaseNegativeSampler.create_from_config( - {'type': config['negative_sampler_type']}, - **kwargs, - ) - - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - negative_sampler=negative_sampler, - mask_prob=config.get('mask_prob', 0.0), - ) - - def __getitem__(self, index): - sample = copy.deepcopy(self._dataset[index]) - - item_sequence = sample['item.ids'] - - if len(item_sequence) < 3: - assert False, 'Something strange is happening' - - # Masked Item Prediction - masked_sequence = [] - positive_sequence = [] - negative_sequence = [] - - for item in item_sequence: - prob = np.random.rand() - - if prob < self._mask_prob: - masked_sequence.append(self._mask_item_idx) - positive_sequence.append(item) - negative_sequence.append( - self._negative_sampler.generate_negative_samples( - sample, - 1, - )[0], - ) - - else: - masked_sequence.append(item) - positive_sequence.append(0) - negative_sequence.append(self._mask_item_idx) - - # Mask last item - masked_sequence[-1] = self._mask_item_idx - positive_sequence[-1] = item_sequence[-1] - negative_sequence[-1] = ( - self._negative_sampler.generate_negative_samples(sample, 1)[0] - ) - assert ( - len(positive_sequence) - == len(negative_sequence) - == len(masked_sequence) - == len(item_sequence) - ) - - # Segment Prediction - sample_length = np.random.randint(1, (len(item_sequence) + 1) // 2) - start_id = np.random.randint(0, len(item_sequence) - sample_length) - negative_start_id = np.random.randint( - 0, - len(self._long_sequence) - sample_length, - ) - masked_segment_sequence = ( - item_sequence[:start_id] - + [self._mask_item_idx] * sample_length - + item_sequence[start_id + sample_length :] - ) - positive_segment = item_sequence[start_id : start_id + sample_length] - negative_segment = self._long_sequence[ - negative_start_id : negative_start_id + sample_length - ] - assert len(positive_segment) == len(negative_segment) == sample_length - - return { - 'user.ids': sample['user.ids'], - 'user.length': sample['user.length'], - 'item.ids': masked_sequence, - 'item.length': len(masked_sequence), - 'positive.ids': item_sequence, - 'positive.length': len(item_sequence), - 'negative.ids': negative_sequence, - 'negative.length': len(negative_sequence), - 'item_segment.ids': masked_segment_sequence, - 'item_segment.length': len(masked_segment_sequence), - 'positive_segment.ids': positive_segment, - 'positive_segment.length': len(positive_segment), - 'negative_segment.ids': negative_segment, - 'negative_segment.length': len(negative_segment), - } - - -class S3RecPretrainEvalSampler(EvalSampler, config_name='s3rec_pretrain'): - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - dataset=kwargs['dataset'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - ) diff --git a/src/irec/infer.py b/src/irec/infer.py deleted file mode 100644 index 984d94ec..00000000 --- a/src/irec/infer.py +++ /dev/null @@ -1,101 +0,0 @@ -from irec.utils import ( - parse_args, - create_logger, - fix_random_seed, - DEVICE, - ensure_checkpoints_dir, -) - -from irec.dataset import BaseDataset -from irec.dataloader import BaseDataloader -from irec.models import BaseModel, TorchModel -from irec.metric import BaseMetric, StatefullMetric - -import json -import numpy as np -import torch - - -logger = 
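A toy trace of the pretraining sample built above may help: take num_items = 5 (so the mask id is 6), the sequence [2, 3, 4, 5], and assume no position except the forced last one gets masked. Note how the mask id doubles as a filler in the negative sequence on unmasked positions:

# masked_sequence   = [2, 3, 4, 6]   # last item is always masked
# positive_sequence = [0, 0, 0, 5]   # 0 marks "no prediction here"
# negative_sequence = [6, 6, 6, n]   # n = a sampled negative item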
create_logger(name=__name__) -seed_val = 42 - - -def inference(dataloader, model, metrics, pred_prefix, labels_prefix): - running_metrics = {} - for metric_name, metric_function in metrics.items(): - running_metrics[metric_name] = [] - - if isinstance(model, TorchModel): - model.eval() - - with torch.no_grad(): - for idx, batch in enumerate(dataloader): - for key, value in batch.items(): - batch[key] = value.to(DEVICE) - batch[pred_prefix] = model(batch) - - for key, values in batch.items(): - batch[key] = values.cpu() - - for metric_name, metric_function in metrics.items(): - running_metrics[metric_name].extend( - metric_function( - inputs=batch, - pred_prefix=pred_prefix, - labels_prefix=labels_prefix, - ), - ) - - for metric_name, metric_function in metrics.items(): - if isinstance(metric_function, StatefullMetric): - running_metrics[metric_name] = metric_function.reduce( - running_metrics[metric_name], - ) - - logger.debug('Inference procedure has been finished!') - logger.debug('Metrics are the following:') - for metric_name, metric_value in running_metrics.items(): - logger.info('{}: {}'.format(metric_name, np.mean(metric_value))) - - -def main(): - fix_random_seed(seed_val) - config = parse_args() - - logger.debug('Inference config: \n{}'.format(json.dumps(config, indent=2))) - - dataset = BaseDataset.create_from_config(config['dataset']) - - _, _, eval_dataset = dataset.get_samplers() - - eval_dataloader = BaseDataloader.create_from_config( - config['dataloader']['validation'], - dataset=eval_dataset, - ) - - model = BaseModel.create_from_config(config['model'], **dataset.meta) - - if isinstance(model, TorchModel): - model = model.to(DEVICE) - ensure_checkpoints_dir() - checkpoint_path = './checkpoints/{}_final_state.pth'.format( - config['experiment_name'], - ) - model.load_state_dict(torch.load(checkpoint_path)) - - metrics = { - metric_name: BaseMetric.create_from_config(metric_cfg, **dataset.meta) - for metric_name, metric_cfg in config['metrics'].items() - } - - _ = inference( - dataloader=eval_dataloader, - model=model, - metrics=metrics, - pred_prefix=config['pred_prefix'], - labels_prefix=config['label_prefix'], - ) - - -if __name__ == '__main__': - main() diff --git a/src/irec/loss/__init__.py b/src/irec/loss/__init__.py deleted file mode 100644 index 9b5ed21c..00000000 --- a/src/irec/loss/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base import * diff --git a/src/irec/loss/base.py b/src/irec/loss/base.py deleted file mode 100644 index d596c4bd..00000000 --- a/src/irec/loss/base.py +++ /dev/null @@ -1,631 +0,0 @@ -import copy - -from irec.utils import ( - MetaParent, - maybe_to_list, -) - -import torch -import torch.nn as nn - - -class BaseLoss(metaclass=MetaParent): - pass - - -class TorchLoss(BaseLoss, nn.Module): - pass - - -class IdentityLoss(BaseLoss, config_name='identity'): - def __call__(self, inputs): - return inputs - - -class CompositeLoss(TorchLoss, config_name='composite'): - def __init__(self, losses, weights=None, output_prefix=None): - super().__init__() - self._losses = losses - self._weights = weights or [1.0] * len(losses) - self._output_prefix = output_prefix - - @classmethod - def create_from_config(cls, config, **kwargs): - losses = [] - weights = [] - - for loss_cfg in copy.deepcopy(config)['losses']: - weight = loss_cfg.pop('weight') if 'weight' in loss_cfg else 1.0 - loss_function = BaseLoss.create_from_config(loss_cfg) - - weights.append(weight) - losses.append(loss_function) - - return cls( - losses=losses, - weights=weights, - 
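The inference loop above relies on a two-phase metric protocol: every metric returns a list of per-sample values for each batch, and a StatefullMetric additionally reduces the collected values once at the end (coverage, for instance, reduces to the share of distinct recommended items). Note that the base class declares reduce(self) without the values argument it is later called with. A compact sketch of the aggregation step, assuming metrics are keyed by name:

import numpy as np

def aggregate_metrics(collected_values, metrics, statefull_types=()):
    # collected_values: name -> per-sample values gathered over all
    # batches; stateful metrics get a final reduce() call, everything
    # else is simply averaged.
    results = {}
    for name, values in collected_values.items():
        metric = metrics[name]
        if isinstance(metric, statefull_types):
            results[name] = metric.reduce(values)
        else:
            results[name] = float(np.mean(values))
    return results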
output_prefix=config.get('output_prefix'), - ) - - def forward(self, inputs): - total_loss = 0.0 - for loss, weight in zip(self._losses, self._weights): - total_loss += weight * loss(inputs) - - if self._output_prefix is not None: - inputs[self._output_prefix] = total_loss.cpu().item() - - return total_loss - - -class BatchLogSoftmaxLoss(TorchLoss, config_name='batch_logsoftmax'): - def __init__(self, predictions_prefix, candidates_prefix): - super().__init__() - self._predictions_prefix = predictions_prefix - self._candidates_prefix = candidates_prefix - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - predictions_prefix=config.get('predictions_prefix'), - candidates_prefix=config.get('candidates_prefix'), - ) - - def forward(self, inputs): # use log soft max - predictions = inputs[self._predictions_prefix] - candidates = inputs[self._candidates_prefix] - - dot_product_matrix = predictions @ candidates.T - - row_log_softmax = nn.LogSoftmax(dim=1) - softmax_matrix = -row_log_softmax(dot_product_matrix) - - diagonal_elements = torch.diag(softmax_matrix) - - loss = diagonal_elements.mean() - - return loss - - -class CrossEntropyLoss(TorchLoss, config_name='ce'): - def __init__(self, predictions_prefix, labels_prefix, output_prefix=None): - super().__init__() - self._pred_prefix = predictions_prefix - self._labels_prefix = labels_prefix - self._output_prefix = output_prefix - - self._loss = nn.CrossEntropyLoss() - - def forward(self, inputs): - all_logits = inputs[self._pred_prefix] # (all_items, num_classes) - all_labels = inputs[ - '{}.ids'.format(self._labels_prefix) - ] # (all_items) - assert all_logits.shape[0] == all_labels.shape[0] - - loss = self._loss(all_logits, all_labels) # (1) - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class BinaryCrossEntropyLoss(TorchLoss, config_name='bce'): - def __init__( - self, - predictions_prefix, - labels_prefix, - with_logits=True, - output_prefix=None, - ): - super().__init__() - self._pred_prefix = predictions_prefix - self._labels_prefix = labels_prefix - self._output_prefix = output_prefix - - if with_logits: - self._loss = nn.BCEWithLogitsLoss() - else: - self._loss = nn.BCELoss() - - def forward(self, inputs): - all_logits = inputs[self._pred_prefix].float() # (all_batch_items) - all_labels = inputs[self._labels_prefix].float() # (all_batch_items) - assert all_logits.shape[0] == all_labels.shape[0] - - loss = self._loss(all_logits, all_labels) # (1) - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class BPRLoss(TorchLoss, config_name='bpr'): - def __init__(self, positive_prefix, negative_prefix, output_prefix=None): - super().__init__() - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - self._output_prefix = output_prefix - - def forward(self, inputs): - pos_scores = inputs[self._positive_prefix] # (all_batch_items) - neg_scores = inputs[self._negative_prefix] # (all_batch_items) - loss = -torch.log( - (pos_scores - neg_scores).sigmoid() + 1e-9, - ).mean() # (1) - - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class RegularizationLoss(TorchLoss, config_name='regularization'): - def __init__(self, prefix, output_prefix=None): - super().__init__() - self._prefix = maybe_to_list(prefix) - self._output_prefix = output_prefix - - def forward(self, inputs): - loss = 0.0 - for prefix in self._prefix: - 
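BatchLogSoftmaxLoss above treats the batch as its own negative pool: position i's positive is candidates[i], every other row is an in-batch negative, and the loss is the mean negative log-softmax of the diagonal. The same computation in three lines, using cross_entropy with arange targets instead of materialising a LogSoftmax module (a sketch, not library code):

import torch
import torch.nn.functional as F

def in_batch_softmax_loss(queries, candidates):
    # queries, candidates: (batch_size, dim); row i is a positive pair
    logits = queries @ candidates.T                         # (B, B)
    targets = torch.arange(logits.shape[0], device=logits.device)
    return F.cross_entropy(logits, targets)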
loss += (1 / 2) * inputs[prefix].pow(2).mean() - - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class FpsLoss(TorchLoss, config_name='fps'): - def __init__( - self, - fst_embeddings_prefix, - snd_embeddings_prefix, - tau, - normalize_embeddings=False, - use_mean=True, - output_prefix=None, - ): - super().__init__() - self._fst_embeddings_prefix = fst_embeddings_prefix - self._snd_embeddings_prefix = snd_embeddings_prefix - self._tau = tau - self._loss_function = nn.CrossEntropyLoss( - reduction='mean' if use_mean else 'sum', - ) - self._normalize_embeddings = normalize_embeddings - self._output_prefix = output_prefix - print(self._tau) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - fst_embeddings_prefix=config['fst_embeddings_prefix'], - snd_embeddings_prefix=config['snd_embeddings_prefix'], - tau=config.get('temperature', 1.0), - normalize_embeddings=config.get('normalize_embeddings', False), - use_mean=config.get('use_mean', True), - output_prefix=config.get('output_prefix') - ) - - def forward(self, inputs): - fst_embeddings = inputs[ - self._fst_embeddings_prefix - ] # (x, embedding_dim) - snd_embeddings = inputs[ - self._snd_embeddings_prefix - ] # (x, embedding_dim) - - batch_size = fst_embeddings.shape[0] - - combined_embeddings = torch.cat( - (fst_embeddings, snd_embeddings), - dim=0, - ) # (2 * x, embedding_dim) - - if self._normalize_embeddings: - combined_embeddings = torch.nn.functional.normalize( - combined_embeddings, - p=2, - dim=-1, - eps=1e-6, - ) # (2 * x, embedding_dim) - - similarity_scores = ( - torch.mm(combined_embeddings, combined_embeddings.T) / self._tau - ) # (2 * x, 2 * x) - - positive_samples = torch.cat( - ( - torch.diag(similarity_scores, batch_size), - torch.diag(similarity_scores, -batch_size), - ), - dim=0, - ).reshape(2 * batch_size, 1) # (2 * x, 1) - assert torch.allclose( - torch.diag(similarity_scores, batch_size), - torch.diag(similarity_scores, -batch_size), - ) - - mask = torch.ones( - 2 * batch_size, - 2 * batch_size, - dtype=torch.bool, - ) # (2 * x, 2 * x) - mask = mask.fill_diagonal_(0) # Remove equal embeddings scores - for i in range(batch_size): # Remove positives - mask[i, batch_size + i] = 0 - mask[batch_size + i, i] = 0 - - negative_samples = similarity_scores[mask].reshape( - 2 * batch_size, - -1, - ) # (2 * x, 2 * x - 2) - - labels = ( - torch.zeros(2 * batch_size).to(positive_samples.device).long() - ) # (2 * x) - logits = torch.cat( - (positive_samples, negative_samples), - dim=1, - ) # (2 * x, 2 * x - 1) - - loss = self._loss_function(logits, labels) / 2 # (1) - - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class SASRecLoss(TorchLoss, config_name='sasrec'): - - def __init__( - self, - positive_prefix, - negative_prefix, - output_prefix=None - ): - super().__init__() - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - self._output_prefix = output_prefix - - def forward(self, inputs): - positive_scores = inputs[self._positive_prefix] # (x) - negative_scores = inputs[self._negative_prefix] # (x) - assert positive_scores.shape[0] == negative_scores.shape[0] - - loss = torch.nn.functional.binary_cross_entropy_with_logits( - positive_scores, torch.ones_like(positive_scores) - ) + torch.nn.functional.binary_cross_entropy_with_logits( - negative_scores, torch.zeros_like(negative_scores) - ) - - if self._output_prefix is not None: - 
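The Python loop above that zeroes out the diagonal and the +/- batch_size off-diagonals of the NT-Xent mask (the same pattern carries a '# TODO optimize' in the DuoRec loss below) vectorises cleanly; a sketch:

import torch

def ntxent_negative_mask(batch_size):
    # True where a pair may serve as a negative: everything except
    # self-similarity (the diagonal) and the two off-diagonals at
    # +/- batch_size, which pair each example with its other view.
    n = 2 * batch_size
    mask = ~torch.eye(n, dtype=torch.bool)
    idx = torch.arange(batch_size)
    mask[idx, idx + batch_size] = False
    mask[idx + batch_size, idx] = False
    return mask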
inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class SamplesSoftmaxLoss(TorchLoss, config_name='sampled_softmax'): - def __init__( - self, - queries_prefix, - positive_prefix, - negative_prefix, - output_prefix=None, - ): - super().__init__() - self._queries_prefix = queries_prefix - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - self._output_prefix = output_prefix - - def forward(self, inputs): - queries_embeddings = inputs[ - self._queries_prefix - ] # (batch_size, embedding_dim) - positive_embeddings = inputs[ - self._positive_prefix - ] # (batch_size, embedding_dim) - negative_embeddings = inputs[ - self._negative_prefix - ] # (num_negatives, embedding_dim) or (batch_size, num_negatives, embedding_dim) - - # b -- batch_size, d -- embedding_dim - positive_scores = torch.einsum( - 'bd,bd->b', - queries_embeddings, - positive_embeddings, - ).unsqueeze(-1) # (batch_size, 1) - - if negative_embeddings.dim() == 2: # (num_negatives, embedding_dim) - # b -- batch_size, n -- num_negatives, d -- embedding_dim - negative_scores = torch.einsum( - 'bd,nd->bn', - queries_embeddings, - negative_embeddings, - ) # (batch_size, num_negatives) - else: - assert ( - negative_embeddings.dim() == 3 - ) # (batch_size, num_negatives, embedding_dim) - # b -- batch_size, n -- num_negatives, d -- embedding_dim - negative_scores = torch.einsum( - 'bd,bnd->bn', - queries_embeddings, - negative_embeddings, - ) # (batch_size, num_negatives) - all_scores = torch.cat( - [positive_scores, negative_scores], - dim=1, - ) # (batch_size, 1 + num_negatives) - - logits = torch.log_softmax( - all_scores, - dim=1, - ) # (batch_size, 1 + num_negatives) - loss = (-logits)[:, 0] # (batch_size) - loss = loss.mean() # (1) - - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class S3RecPretrainLoss(TorchLoss, config_name='s3rec_pretrain'): - def __init__( - self, - positive_prefix, - negative_prefix, - representation_prefix, - output_prefix=None, - ): - super().__init__() - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - self._representation_prefix = representation_prefix - self._criterion = nn.BCEWithLogitsLoss(reduction='none') - self._output_prefix = output_prefix - - def forward(self, inputs): - positive_embeddings = inputs[ - self._positive_prefix - ] # (x, embedding_dim) - negative_embeddings = inputs[ - self._negative_prefix - ] # (x, embedding_dim) - current_embeddings = inputs[ - self._representation_prefix - ] # (x, embedding_dim) - assert ( - positive_embeddings.shape[0] - == negative_embeddings.shape[0] - == current_embeddings.shape[0] - ) - - positive_scores = torch.einsum( - 'bd,bd->b', - positive_embeddings, - current_embeddings, - ) # (x) - - negative_scores = torch.einsum( - 'bd,bd->b', - negative_embeddings, - current_embeddings, - ) # (x) - - distance = torch.sigmoid(positive_scores) - torch.sigmoid( - negative_scores, - ) # (x) - loss = torch.sum( - self._criterion( - distance, - torch.ones_like(distance, dtype=torch.float32), - ), - ) # (1) - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class Cl4sRecLoss(TorchLoss, config_name='cl4srec'): - def __init__( - self, - current_representation, - all_items_representation, - tau=1.0, - output_prefix=None, - ): - super().__init__() - self._current_representation = current_representation - self._all_items_representation = all_items_representation - 
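SamplesSoftmaxLoss above (the config name 'sampled_softmax' suggests the class name carries a typo) accepts two negative layouts: a pool of shape (num_negatives, dim) shared by the whole batch, or per-example negatives of shape (batch_size, num_negatives, dim). Functionally it reduces to the following sketch, where the positive logit is always column 0:

import torch

def sampled_softmax_loss(queries, positives, negatives):
    pos = (queries * positives).sum(-1, keepdim=True)          # (B, 1)
    if negatives.dim() == 2:
        neg = queries @ negatives.T                            # shared pool
    else:
        neg = torch.einsum('bd,bnd->bn', queries, negatives)   # per-example
    logits = torch.cat([pos, neg], dim=1)                      # (B, 1 + N)
    return -torch.log_softmax(logits, dim=1)[:, 0].mean()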
self._loss_function = nn.CrossEntropyLoss() - self._tau = tau - self._output_prefix = output_prefix - - def forward(self, inputs): - current_representation = inputs[ - self._current_representation - ] # (batch_size, embedding_dim) - all_items_representation = inputs[ - self._all_items_representation - ] # (batch_size, num_negatives + 1, embedding_dim) - - batch_size = current_representation.shape[0] - - logits = torch.einsum( - 'bnd,bd->bn', - all_items_representation, - current_representation, - ) # (batch_size, num_negatives + 1) - labels = logits.new_zeros(batch_size) # (batch_size) - - loss = self._loss_function(logits, labels) - - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class DuorecSSLLoss(TorchLoss, config_name='duorec_ssl'): - def __init__( - self, - original_embedding_prefix, - dropout_embedding_prefix, - similar_embedding_prefix, - normalize_embeddings=False, - tau=1.0, - output_prefix=None, - ): - super().__init__() - self._original_embedding_prefix = original_embedding_prefix - self._dropout_embedding_prefix = dropout_embedding_prefix - self._similar_embedding_prefix = similar_embedding_prefix - self._normalize_embeddings = normalize_embeddings - self._output_prefix = output_prefix - self._tau = tau - self._loss_function = nn.CrossEntropyLoss(reduction='mean') - - def _compute_partial_loss(self, fst_embeddings, snd_embeddings): - batch_size = fst_embeddings.shape[0] - - combined_embeddings = torch.cat( - (fst_embeddings, snd_embeddings), - dim=0, - ) # (2 * x, embedding_dim) - - if self._normalize_embeddings: - combined_embeddings = torch.nn.functional.normalize( - combined_embeddings, - p=2, - dim=-1, - eps=1e-6, - ) - - similarity_scores = ( - torch.mm(combined_embeddings, combined_embeddings.T) / self._tau - ) # (2 * x, 2 * x) - - positive_samples = torch.cat( - ( - torch.diag(similarity_scores, batch_size), - torch.diag(similarity_scores, -batch_size), - ), - dim=0, - ).reshape(2 * batch_size, 1) # (2 * x, 1) - - # TODO optimize - mask = torch.ones( - 2 * batch_size, - 2 * batch_size, - dtype=torch.bool, - ) # (2 * x, 2 * x) - mask = mask.fill_diagonal_(0) # Remove equal embeddings scores - for i in range(batch_size): # Remove positives - mask[i, batch_size + i] = 0 - mask[batch_size + i, i] = 0 - - negative_samples = similarity_scores[mask].reshape( - 2 * batch_size, - -1, - ) # (2 * x, 2 * x - 2) - - labels = ( - torch.zeros(2 * batch_size).to(positive_samples.device).long() - ) # (2 * x) - logits = torch.cat( - (positive_samples, negative_samples), - dim=1, - ) # (2 * x, 2 * x - 1) - - loss = self._loss_function(logits, labels) / 2 # (1) - - return loss - - def forward(self, inputs): - original_embeddings = inputs[ - self._original_embedding_prefix - ] # (x, embedding_dim) - dropout_embeddings = inputs[ - self._dropout_embedding_prefix - ] # (x, embedding_dim) - similar_embeddings = inputs[ - self._similar_embedding_prefix - ] # (x, embedding_dim) - - dropout_loss = self._compute_partial_loss( - original_embeddings, - dropout_embeddings, - ) - ssl_loss = self._compute_partial_loss( - original_embeddings, - similar_embeddings, - ) - - loss = dropout_loss + ssl_loss - - if self._output_prefix is not None: - inputs[f'{self._output_prefix}_dropout'] = ( - dropout_loss.cpu().item() - ) - inputs[f'{self._output_prefix}_ssl'] = ssl_loss.cpu().item() - inputs[self._output_prefix] = loss.cpu().item() - - return loss - - -class MCLSRLoss(TorchLoss, config_name='mclsr'): - def __init__( - self, - all_scores_prefix, - 
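Cl4sRecLoss above also uses the column-0 convention: the true candidate is scored first, so every target class is 0. One caveat: logits.new_zeros(batch_size) inherits the float dtype of the logits, while nn.CrossEntropyLoss needs integer class indices for (batch,)-shaped targets, so the labels should be created long, e.g.:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
logits = torch.randn(8, 101)               # (batch, 1 + num_negatives)
labels = torch.zeros(8, dtype=torch.long)  # class 0 = the positive
loss = loss_fn(logits, labels)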
mask_prefix, - normalize_embeddings=False, - tau=1.0, - output_prefix=None, - ): - super().__init__() - self._all_scores_prefix = all_scores_prefix - self._mask_prefix = mask_prefix - self._normalize_embeddings = normalize_embeddings - self._output_prefix = output_prefix - self._tau = tau - - def forward(self, inputs): - all_scores = inputs[ - self._all_scores_prefix - ] # (batch_size, batch_size, seq_len) - mask = inputs[self._mask_prefix] # (batch_size) - - batch_size = mask.shape[0] - seq_len = mask.shape[1] - - positive_mask = torch.eye(batch_size, device=mask.device).bool() - - positive_scores = all_scores[positive_mask] # (batch_size, seq_len) - negative_scores = torch.reshape( - all_scores[~positive_mask], - shape=(batch_size, batch_size - 1, seq_len), - ) # (batch_size, batch_size - 1, seq_len) - assert torch.allclose(all_scores[0, 1], negative_scores[0, 0]) - assert torch.allclose(all_scores[-1, -2], negative_scores[-1, -1]) - assert torch.allclose(all_scores[0, 0], positive_scores[0]) - assert torch.allclose(all_scores[-1, -1], positive_scores[-1]) - - # Maybe try mean over sequence TODO - loss = torch.sum( - torch.log( - torch.sigmoid(positive_scores.unsqueeze(1) - negative_scores), - ), - ) # (1) - - if self._output_prefix is not None: - inputs[self._output_prefix] = loss.cpu().item() - - return loss diff --git a/src/irec/metric/__init__.py b/src/irec/metric/__init__.py deleted file mode 100644 index 032b3f99..00000000 --- a/src/irec/metric/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .base import BaseMetric, StatefullMetric, StaticMetric - -__all__ = [ - 'BaseMetric', - 'StatefullMetric', - 'StaticMetric', -] diff --git a/src/irec/metric/base.py b/src/irec/metric/base.py deleted file mode 100644 index da689653..00000000 --- a/src/irec/metric/base.py +++ /dev/null @@ -1,202 +0,0 @@ -from irec.utils import MetaParent - -import torch - - -class BaseMetric(metaclass=MetaParent): - pass - - -class StatefullMetric(BaseMetric): - def reduce(self): - raise NotImplementedError - - -class StaticMetric(BaseMetric, config_name='dummy'): - def __init__(self, name, value): - self._name = name - self._value = value - - def __call__(self, inputs): - inputs[self._name] = self._value - - return inputs - - -class CompositeMetric(BaseMetric, config_name='composite'): - def __init__(self, metrics): - self._metrics = metrics - - @classmethod - def create_from_config(cls, config): - return cls( - metrics=[ - BaseMetric.create_from_config(cfg) for cfg in config['metrics'] - ], - ) - - def __call__(self, inputs): - for metric in self._metrics: - inputs = metric(inputs) - return inputs - - -class NDCGMetric(BaseMetric, config_name='ndcg'): - def __init__(self, k): - self._k = k - - def __call__(self, inputs, pred_prefix, labels_prefix): - predictions = inputs[pred_prefix][ - :, - : self._k, - ].float() # (batch_size, top_k_indices) - labels = inputs['{}.ids'.format(labels_prefix)].float() # (batch_size) - - - assert labels.shape[0] == predictions.shape[0] - - hits = torch.eq( - predictions, - labels[..., None], - ).float() # (batch_size, top_k_indices) - discount_factor = 1 / torch.log2( - torch.arange(1, self._k + 1, 1).float() + 1.0, - ).to(hits.device) # (k) - dcg = torch.einsum('bk,k->b', hits, discount_factor) # (batch_size) - - return dcg.cpu().tolist() - - -class RecallMetric(BaseMetric, config_name='recall'): - def __init__(self, k): - self._k = k - - def __call__(self, inputs, pred_prefix, labels_prefix): - predictions = inputs[pred_prefix][ - :, - : self._k, - ].float() # (batch_size, 
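A worked example of the single-label NDCG above, with k = 3: the discount for ranks 1..3 is 1 / log2(rank + 1) = [1.0, 0.631, 0.5], so a hit at rank 2 scores 0.631 (with exactly one relevant item the ideal DCG is 1, so DCG equals NDCG here):

import torch

preds  = torch.tensor([[7, 2, 9]])   # top-3 recommendations
labels = torch.tensor([2])           # held-out item: hit at rank 2

hits = (preds == labels[:, None]).float()              # [[0., 1., 0.]]
discount = 1.0 / torch.log2(torch.arange(2, 5).float())
dcg = (hits * discount).sum(dim=-1)                    # tensor([0.6309])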
top_k_indices) - labels = inputs['{}.ids'.format(labels_prefix)].float() # (batch_size) - - assert labels.shape[0] == predictions.shape[0] - - hits = torch.eq( - predictions, - labels[..., None], - ).float() # (batch_size, top_k_indices) - recall = hits.sum(dim=-1) # (batch_size) - - return recall.cpu().tolist() - - -class CoverageMetric(StatefullMetric, config_name='coverage'): - def __init__(self, k, num_items): - self._k = k - self._num_items = num_items - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls(k=config['k'], num_items=kwargs['num_items']) - - def __call__(self, inputs, pred_prefix, labels_prefix): - predictions = inputs[pred_prefix][ - :, - : self._k, - ].float() # (batch_size, top_k_indices) - return ( - predictions.view(-1).long().cpu().detach().tolist() - ) # (batch_size * k) - - def reduce(self, values): - return len(set(values)) / self._num_items - -class MCLSRNDCGMetric(BaseMetric, config_name='mclsr-ndcg'): - def __init__(self, k): - self._k = k - - def __call__(self, inputs, pred_prefix, labels_prefix): - predictions = inputs[pred_prefix][:, :self._k] # (batch_size, k) - labels_flat = inputs[f'{labels_prefix}.ids'] # (total_labels,) - labels_lengths = inputs[f'{labels_prefix}.length'] # (batch_size,) - - assert predictions.shape[0] == labels_lengths.shape[0] - - dcg_scores = [] - offset = 0 - for i in range(predictions.shape[0]): - user_predictions = predictions[i] - num_user_labels = labels_lengths[i] - user_labels = labels_flat[offset : offset + num_user_labels] - offset += num_user_labels - - hits_mask = torch.isin(user_predictions, user_labels) # (k,) -> True/False - - positions = torch.arange(2, self._k + 2, device=predictions.device) - weights = 1 / torch.log2(positions.float()) - dcg = (hits_mask.float() * weights).sum() - - num_ideal_hits = min(self._k, num_user_labels) - idcg_weights = weights[:num_ideal_hits] - idcg = idcg_weights.sum() - - ndcg = dcg / idcg if idcg > 0 else torch.tensor(0.0) - dcg_scores.append(ndcg.cpu().item()) - - return dcg_scores - - -class MCLSRRecallMetric(BaseMetric, config_name='mclsr-recall'): - def __init__(self, k): - self._k = k - - def __call__(self, inputs, pred_prefix, labels_prefix): - predictions = inputs[pred_prefix][:, :self._k] # (batch_size, k) - labels_flat = inputs[f'{labels_prefix}.ids'] # (total_labels,) - labels_lengths = inputs[f'{labels_prefix}.length'] # (batch_size,) - - assert predictions.shape[0] == labels_lengths.shape[0] - - recall_scores = [] - offset = 0 - for i in range(predictions.shape[0]): - user_predictions = predictions[i] - num_user_labels = labels_lengths[i] - user_labels = labels_flat[offset : offset + num_user_labels] - offset += num_user_labels - - hits = torch.isin(user_predictions, user_labels).sum().float() - - recall = hits / num_user_labels if num_user_labels > 0 else torch.tensor(0.0) - recall_scores.append(recall.cpu().item()) - - return recall_scores - -class MCLSRHitRateMetric(BaseMetric, config_name='mclsr-hit'): - def __init__(self, k): - self._k = k - - def __call__(self, inputs, pred_prefix, labels_prefix): - predictions = inputs[pred_prefix][:, :self._k] # (batch_size, k) - labels_flat = inputs[f'{labels_prefix}.ids'] # (total_labels,) - labels_lengths = inputs[f'{labels_prefix}.length'] # (batch_size,) - - assert predictions.shape[0] == labels_lengths.shape[0] - - hit_scores = [] - offset = 0 - for i in range(predictions.shape[0]): - user_predictions = predictions[i] - num_user_labels = labels_lengths[i] - - if num_user_labels == 0: - 
hit_scores.append(0.0) - continue - - user_labels = labels_flat[offset : offset + num_user_labels] - offset += num_user_labels - - is_hit = torch.isin(user_predictions, user_labels).any() - - hit_scores.append(float(is_hit)) - - return hit_scores \ No newline at end of file diff --git a/src/irec/models/__init__.py b/src/irec/models/__init__.py index 6e0fbd57..8c8013de 100644 --- a/src/irec/models/__init__.py +++ b/src/irec/models/__init__.py @@ -1,41 +1,7 @@ -from .base import BaseModel, SequentialTorchModel, TorchModel -from .bert4rec import Bert4RecModel -from .bert4rec_cls import Bert4RecModelCLS -from .cl4srec import Cl4SRecModel -from .duorec import DuoRecModel -from .graph_seq_rec import GraphSeqRecModel -from .gru4rec import GRU4RecModel -from .lightgcn import LightGCNModel -from .mclsr import MCLSRModel -from .mrgsrec import MRGSRecModel -from .ngcf import NgcfModel -from .pop import PopModel -from .pure_mf import PureMF -from .random import RandomModel -from .sasrec import SasRecModel, SasRecInBatchModel -from .sasrec_ce import SasRecCeModel -from .s3rec import S3RecModel +from irec.models.base import create_masked_tensor, TorchModel + __all__ = [ - 'BaseModel', - 'SequentialTorchModel', + 'create_masked_tensor', 'TorchModel', - 'Bert4RecModel', - 'Bert4RecModelCLS', - 'Cl4SRecModel', - 'DuoRecModel', - 'GraphSeqRecModel', - 'GRU4RecModel', - 'LightGCNModel', - 'MCLSRModel', - 'MRGSRecModel', - 'NgcfModel', - 'PopModel', - 'PureMF', - 'RandomModel', - 'SasRecModel', - 'SasRecInBatchModel', - 'SasRecCeModel', - 'S3RecModel', - 'SasRecRealModel', ] diff --git a/src/irec/models/base.py b/src/irec/models/base.py index 2f059f7c..12699210 100644 --- a/src/irec/models/base.py +++ b/src/irec/models/base.py @@ -1,23 +1,36 @@ -from irec.utils import MetaParent - -from irec.utils import ( - DEVICE, - create_masked_tensor, - get_activation_function, - create_logger, -) - import torch import torch.nn as nn -logger = create_logger(name=__name__) +def create_masked_tensor(data, lengths, is_right_aligned=False): + batch_size = lengths.shape[0] + max_sequence_length = lengths.max().item() + + if len(data.shape) == 1: # only indices + padded_tensor = torch.zeros( + batch_size, max_sequence_length, + dtype=data.dtype, device=data.device + ) # (batch_size, max_seq_len) + else: + assert len(data.shape) == 2 # embeddings + padded_tensor = torch.zeros( + batch_size, max_sequence_length, data.shape[-1], + dtype=data.dtype, device=data.device + ) # (batch_size, max_seq_len, emb_dim) -class BaseModel(metaclass=MetaParent): - pass + mask = torch.arange( + end=max_sequence_length, + device=data.device + )[None].tile([batch_size, 1]) < lengths[:, None] # (batch_size, max_seq_len) + if is_right_aligned: + mask = torch.flip(mask, dims=[-1]) + padded_tensor[mask] = data -class TorchModel(nn.Module, BaseModel): + return padded_tensor, mask + + +class TorchModel(nn.Module): @torch.no_grad() def _init_weights(self, initializer_range): for key, value in self.named_parameters(): @@ -29,195 +42,23 @@ def _init_weights(self, initializer_range): value.data, std=initializer_range, a=-2 * initializer_range, - b=2 * initializer_range, + b=2 * initializer_range ) elif 'bias' in key: nn.init.zeros_(value.data) + elif 'codebook' in key: + nn.init.trunc_normal_( + value.data, + std=initializer_range, + a=-2 * initializer_range, + b=2 * initializer_range + ) + elif 'bos_embedding' in key: + nn.init.trunc_normal_( + value.data, + std=initializer_range, + a=-2 * initializer_range, + b=2 * initializer_range, + ) else: - raise 
ValueError(f'Unknown transformer weight: {key}') - - @staticmethod - def _get_last_embedding(embeddings, mask): - lengths = torch.sum(mask, dim=-1) # (batch_size) - lengths = lengths - 1 # (batch_size) - last_masks = mask.gather( - dim=1, - index=lengths[:, None], - ) # (batch_size, 1) - lengths = torch.tile( - lengths[:, None, None], - (1, 1, embeddings.shape[-1]), - ) # (batch_size, 1, emb_dim) - last_embeddings = embeddings.gather( - dim=1, - index=lengths, - ) # (batch_size, 1, emb_dim) - last_embeddings = last_embeddings[last_masks] # (batch_size, emb_dim) - if not torch.allclose(embeddings[mask][-1], last_embeddings[-1]): - logger.debug(f'Embeddings: {embeddings}') - logger.debug( - f'Lengths: {lengths}, max: {lengths.max()}, min: {lengths.min()}', - ) - logger.debug(f'Last embedding from mask: {embeddings[mask][-1]}') - logger.debug(f'Last embedding from gather: {last_embeddings[-1]}') - assert False - return last_embeddings - - -class SequentialTorchModel(TorchModel): - def __init__( - self, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-5, - is_causal=True, - ): - super().__init__() - self._is_causal = is_causal - self._num_items = num_items - self._num_heads = num_heads - self._embedding_dim = embedding_dim - - self._item_embeddings = nn.Embedding( - num_embeddings=num_items - + 2, # add zero embedding + mask embedding - embedding_dim=embedding_dim, - ) - self._position_embeddings = nn.Embedding( - num_embeddings=max_sequence_length - + 1, # in order to include `max_sequence_length` value - embedding_dim=embedding_dim, - ) - - self._layernorm = nn.LayerNorm(embedding_dim, eps=layer_norm_eps) - self._dropout = nn.Dropout(dropout) - - transformer_encoder_layer = nn.TransformerEncoderLayer( - d_model=embedding_dim, - nhead=num_heads, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=get_activation_function(activation), - layer_norm_eps=layer_norm_eps, - batch_first=True, - ) - self._encoder = nn.TransformerEncoder( - transformer_encoder_layer, - num_layers, - ) - - def _apply_sequential_encoder(self, events, lengths, add_cls_token=False): - embeddings = self._item_embeddings( - events, - ) # (all_batch_events, embedding_dim) - - embeddings, mask = create_masked_tensor( - data=embeddings, - lengths=lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - batch_size = mask.shape[0] - seq_len = mask.shape[1] - - positions = ( - torch.arange( - start=seq_len - 1, - end=-1, - step=-1, - device=mask.device, - )[None] - .tile([batch_size, 1]) - .long() - ) # (batch_size, seq_len) - positions_mask = ( - positions < lengths[:, None] - ) # (batch_size, max_seq_len) - - positions = positions[positions_mask] # (all_batch_events) - position_embeddings = self._position_embeddings( - positions, - ) # (all_batch_events, embedding_dim) - position_embeddings, _ = create_masked_tensor( - data=position_embeddings, - lengths=lengths, - ) # (batch_size, seq_len, embedding_dim) - assert torch.allclose(position_embeddings[~mask], embeddings[~mask]) - - embeddings = ( - embeddings + position_embeddings - ) # (batch_size, seq_len, embedding_dim) - - embeddings = self._layernorm( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = self._dropout( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - - embeddings[~mask] = 0 - - if add_cls_token: - cls_token_tensor = self._cls_token.unsqueeze(0).unsqueeze(0) - cls_token_expanded = torch.tile( 
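_get_last_embedding above gathers the hidden state at each sequence's final valid position (the debug branch cross-checks the gather against a mask-based lookup). With the left-aligned masks produced by create_masked_tensor, the same result can be had by plain indexing; a sketch:

import torch

def last_valid_embedding(embeddings, mask):
    # embeddings: (batch, seq_len, dim); mask: (batch, seq_len) bool,
    # True on valid positions, left-aligned.
    last_idx = mask.sum(dim=-1) - 1                    # (batch,)
    batch_idx = torch.arange(embeddings.shape[0], device=embeddings.device)
    return embeddings[batch_idx, last_idx]             # (batch, dim)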
- cls_token_tensor, - (batch_size, 1, 1), - ) - embeddings = torch.cat((cls_token_expanded, embeddings), dim=1) - mask = torch.cat( - ( - torch.ones( - (batch_size, 1), - dtype=torch.bool, - device=DEVICE, - ), - mask, - ), - dim=1, - ) - - if self._is_causal: - causal_mask = ( - torch.tril(torch.ones(seq_len, seq_len)).bool().to(DEVICE) - ) # (seq_len, seq_len) - embeddings = self._encoder( - src=embeddings, - mask=~causal_mask, - src_key_padding_mask=~mask, - ) # (batch_size, seq_len, embedding_dim) - else: - embeddings = self._encoder( - src=embeddings, - src_key_padding_mask=~mask, - ) # (batch_size, seq_len, embedding_dim) - - return embeddings, mask - - @staticmethod - def _add_cls_token(items, lengths, cls_token_id=0): - num_items = items.shape[0] - batch_size = lengths.shape[0] - num_new_items = num_items + batch_size - - new_items = ( - torch.ones(num_new_items, dtype=items.dtype, device=items.device) - * cls_token_id - ) # (num_new_items) - - old_items_mask = torch.zeros_like(new_items).bool() # (num_new_items) - old_items_mask = ~old_items_mask.scatter( - src=torch.ones_like(lengths).bool(), - dim=0, - index=torch.cat( - [torch.LongTensor([0]).to(DEVICE), lengths + 1], - ).cumsum(dim=0)[:-1], - ) # (num_new_items) - new_items[old_items_mask] = items - new_length = lengths + 1 - - return new_items, new_length + raise ValueError(f'Unknown transformer weight: {key}') \ No newline at end of file diff --git a/src/irec/models/bert4rec.py b/src/irec/models/bert4rec.py deleted file mode 100644 index a56cbfa1..00000000 --- a/src/irec/models/bert4rec.py +++ /dev/null @@ -1,129 +0,0 @@ -from .base import SequentialTorchModel - -import torch -import torch.nn as nn - - -class Bert4RecModel(SequentialTorchModel, config_name='bert4rec'): - def __init__( - self, - sequence_prefix, - labels_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='gelu', - layer_norm_eps=1e-5, - initializer_range=0.02, - ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - is_causal=False, - ) - self._sequence_prefix = sequence_prefix - self._labels_prefix = labels_prefix - - self._output_projection = nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - ) - - self._bias = nn.Parameter( - data=torch.zeros(num_items + 2), - requires_grad=True, - ) - - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - labels_prefix=config['labels_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get( - 'num_heads', - int(config['embedding_dim'] // 64), - ), - num_layers=config['num_layers'], - dim_feedforward=config.get( - 'dim_feedforward', - 4 * config['embedding_dim'], - ), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02), - ) - - def forward(self, inputs): - all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) - ] # (all_batch_events) - all_sample_lengths = inputs[ - '{}.length'.format(self._sequence_prefix) - ] # (batch_size) - - embeddings, mask = self._apply_sequential_encoder( - all_sample_events, - all_sample_lengths, - ) # 
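The contract of the new create_masked_tensor is easiest to see on a toy batch: a flat id tensor plus per-sequence lengths come back as a zero-padded tensor and a boolean validity mask, left-aligned by default and right-aligned on request (imported here via the new models __init__ shown above):

import torch
from irec.models import create_masked_tensor

data    = torch.tensor([1, 2, 3, 4, 5])  # [1, 2, 3] and [4, 5], flattened
lengths = torch.tensor([3, 2])

padded, mask = create_masked_tensor(data, lengths)
# padded: [[1, 2, 3],      mask: [[True, True,  True],
#          [4, 5, 0]]             [True, True, False]]

padded_r, mask_r = create_masked_tensor(data, lengths, is_right_aligned=True)
# padded_r: [[1, 2, 3],    mask_r: [[True,  True, True],
#            [0, 4, 5]]            [False, True, True]]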
(batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - embeddings = self._output_projection( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = torch.nn.functional.gelu( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = torch.einsum( - 'bsd,nd->bsn', - embeddings, - self._item_embeddings.weight, - ) # (batch_size, seq_len, num_items) - embeddings += self._bias[ - None, - None, - :, - ] # (batch_size, seq_len, num_items) - - if self.training: # training mode - all_sample_labels = inputs[ - '{}.ids'.format(self._labels_prefix) - ] # (all_batch_events) - embeddings = embeddings[mask] # (all_batch_events, num_items) - labels_mask = (all_sample_labels != 0).bool() # (all_batch_events) - - needed_logits = embeddings[ - labels_mask - ] # (non_zero_events, num_items) - needed_labels = all_sample_labels[labels_mask] # (non_zero_events) - - return {'logits': needed_logits, 'labels.ids': needed_labels} - else: # eval mode - candidate_scores = self._get_last_embedding( - embeddings, - mask, - ) # (batch_size, num_items) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - _, indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20) - - return indices diff --git a/src/irec/models/bert4rec_cls.py b/src/irec/models/bert4rec_cls.py deleted file mode 100644 index 19b1004c..00000000 --- a/src/irec/models/bert4rec_cls.py +++ /dev/null @@ -1,114 +0,0 @@ -from .base import SequentialTorchModel - -import torch -import torch.nn as nn - - -class Bert4RecModelCLS(SequentialTorchModel, config_name='bert4rec_cls'): - def __init__( - self, - sequence_prefix, - labels_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='gelu', - layer_norm_eps=1e-5, - initializer_range=0.02, - ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - is_causal=False, - ) - self._sequence_prefix = sequence_prefix - self._labels_prefix = labels_prefix - - self._output_projection = nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - ) - - self._bias = nn.Parameter( - data=torch.zeros(num_items + 2), - requires_grad=True, - ) - - self._init_weights(initializer_range) - - self._cls_token = nn.Parameter(torch.rand(embedding_dim)) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - labels_prefix=config['labels_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get( - 'num_heads', - int(config['embedding_dim'] // 64), - ), - num_layers=config['num_layers'], - dim_feedforward=config.get( - 'dim_feedforward', - 4 * config['embedding_dim'], - ), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02), - ) - - def forward(self, inputs): - all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) - ] # (all_batch_events) - all_sample_lengths = inputs[ - '{}.length'.format(self._sequence_prefix) - ] # (batch_size) - - embeddings, mask = self._apply_sequential_encoder( - events=all_sample_events, - lengths=all_sample_lengths, - add_cls_token=True, - ) # 
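The BERT4Rec output head above is weight-tied: hidden states are projected, passed through GELU, then scored against the input item embedding table plus a per-item bias, so no separate output matrix is learned. The same head as a self-contained module (an illustrative sketch, not library code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TiedOutputHead(nn.Module):
    def __init__(self, item_embeddings: nn.Embedding):
        super().__init__()
        dim = item_embeddings.embedding_dim
        self.proj = nn.Linear(dim, dim)
        self.item_embeddings = item_embeddings  # shared with the encoder
        self.bias = nn.Parameter(torch.zeros(item_embeddings.num_embeddings))

    def forward(self, hidden):  # (batch, seq_len, dim)
        h = F.gelu(self.proj(hidden))
        # (batch, seq_len, num_embeddings): one logit per item id
        return h @ self.item_embeddings.weight.T + self.bias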
(batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - embeddings = self._output_projection( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - predictions = embeddings[:, 0, :] # (batch_size, embedding_dim) - - if self.training: # training mode - candidates = self._item_embeddings( - inputs['{}.ids'.format(self._labels_prefix)], - ) # (batch_size, embedding_dim) - - return {'predictions': predictions, 'candidates': candidates} - else: # eval mode - candidate_scores = torch.einsum( - 'bd,nd->bn', - predictions, - self._item_embeddings.weight, - ) # (batch_size, num_items + 2) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - _, indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20) - - return indices diff --git a/src/irec/models/cl4srec.py b/src/irec/models/cl4srec.py deleted file mode 100644 index 5023c2af..00000000 --- a/src/irec/models/cl4srec.py +++ /dev/null @@ -1,159 +0,0 @@ -from .base import SequentialTorchModel - -import torch - - -class Cl4SRecModel(SequentialTorchModel, config_name='cl4srec'): - def __init__( - self, - sequence_prefix, - fst_augmented_sequence_prefix, - snd_augmented_sequence_prefix, - positive_prefix, - negative_prefix, - labels_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-5, - initializer_range=0.02, - ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - is_causal=True, - ) - self._sequence_prefix = sequence_prefix - self._fst_augmented_sequence_prefix = fst_augmented_sequence_prefix - self._snd_augmented_sequence_prefix = snd_augmented_sequence_prefix - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - self._labels_prefix = labels_prefix - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - fst_augmented_sequence_prefix=config[ - 'fst_augmented_sequence_prefix' - ], - snd_augmented_sequence_prefix=config[ - 'snd_augmented_sequence_prefix' - ], - positive_prefix=config['positive_prefix'], - negative_prefix=config['negative_prefix'], - labels_prefix=config['labels_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - num_layers=config['num_layers'], - num_heads=config['num_heads'], - embedding_dim=config['embedding_dim'], - dim_feedforward=config['dim_feedforward'], - dropout=config['dropout'], - activation=config['activation'], - layer_norm_eps=config['layer_norm_eps'], - initializer_range=config['initializer_range'], - ) - - def forward(self, inputs): - all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) - ] # (all_batch_events) - all_sample_lengths = inputs[ - '{}.length'.format(self._sequence_prefix) - ] # (batch_size) - - embeddings, mask = self._apply_sequential_encoder( - all_sample_events, - all_sample_lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - last_embeddings = self._get_last_embedding( - embeddings, - mask, - ) # (batch_size, embedding_dim) - - if self.training: # training mode - items_logits = torch.einsum( - 'bd,nd->bn', - last_embeddings, - 
self._item_embeddings.weight, - ) # (batch_size, num_items) - - # TODO remove this check - labels = inputs[ - '{}.ids'.format(self._labels_prefix) - ] # (batch_size) - assert torch.allclose( - self._item_embeddings(labels), - self._item_embeddings.weight[labels], - ) - - all_fst_aug_sample_events = inputs[ - '{}.ids'.format(self._fst_augmented_sequence_prefix) - ] # (all_batch_events) - all_fst_aug_sample_lengths = inputs[ - '{}.length'.format(self._fst_augmented_sequence_prefix) - ] # (batch_size) - fst_aug_embeddings, fst_aug_mask = self._apply_sequential_encoder( - all_fst_aug_sample_events, - all_fst_aug_sample_lengths, - ) # (batch_size, fst_aug_seq_len, embedding_dim), (batch_size, fst_aug_seq_len) - last_fst_aug_embeddings = self._get_last_embedding( - fst_aug_embeddings, - fst_aug_mask, - ) # (batch_size, embedding_dim) - - all_snd_aug_sample_events = inputs[ - '{}.ids'.format(self._snd_augmented_sequence_prefix) - ] # (all_batch_events) - all_snd_aug_sample_lengths = inputs[ - '{}.length'.format(self._snd_augmented_sequence_prefix) - ] # (batch_size) - snd_aug_embeddings, snd_aug_mask = self._apply_sequential_encoder( - all_snd_aug_sample_events, - all_snd_aug_sample_lengths, - ) # (batch_size, snd_aug_seq_len, embedding_dim), (batch_size, snd_aug_seq_len) - last_snd_aug_embeddings = self._get_last_embedding( - snd_aug_embeddings, - snd_aug_mask, - ) # (batch_size, embedding_dim) - - return { - 'logits': items_logits, - 'sequence_representation': last_embeddings, - 'fst_aug_sequence_representation': last_fst_aug_embeddings, - 'snd_aug_sequence_representation': last_snd_aug_embeddings, - } - else: # eval mode - candidate_embeddings = ( - self._item_embeddings.weight - ) # (num_items, embedding_dim) - candidate_scores = torch.einsum( - 'bd,nd->bn', - last_embeddings, - candidate_embeddings, - ) # (batch_size, num_items) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - _, indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20) - - return indices diff --git a/src/irec/models/deepfm.py b/src/irec/models/deepfm.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/irec/models/duorec.py b/src/irec/models/duorec.py deleted file mode 100644 index 0056f3fd..00000000 --- a/src/irec/models/duorec.py +++ /dev/null @@ -1,172 +0,0 @@ -from .base import SequentialTorchModel - -import torch -import torch.nn as nn - - -class DuoRecModel(SequentialTorchModel, config_name='duorec'): - def __init__( - self, - sequence_prefix, - augmented_sequence_prefix, - labels_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-5, - initializer_range=0.02, - is_causal=True, - ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - is_causal=is_causal, - ) - self._sequence_prefix = sequence_prefix - self._augmented_sequence_prefix = augmented_sequence_prefix - self._labels_prefix = labels_prefix - - # TODO taken from duorec github - # self._init_weights(initializer_range) - self._initializer_range = initializer_range - self.apply(self._init_weights) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - 
augmented_sequence_prefix=config['augmented_sequence_prefix'], - labels_prefix=config['labels_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - num_layers=config['num_layers'], - num_heads=config['num_heads'], - embedding_dim=config['embedding_dim'], - dim_feedforward=config['dim_feedforward'], - dropout=config['dropout'], - activation=config['activation'], - layer_norm_eps=config['layer_norm_eps'], - initializer_range=config['initializer_range'], - ) - - # TODO taken from duorec github - def _init_weights(self, module): - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self._initializer_range) - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - def forward(self, inputs): - all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) - ] # (all_batch_events) - all_sample_lengths = inputs[ - '{}.length'.format(self._sequence_prefix) - ] # (batch_size) - - embeddings, mask = self._apply_sequential_encoder( - all_sample_events, - all_sample_lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - last_embeddings = self._get_last_embedding( - embeddings, - mask, - ) # (batch_size, embedding_dim) - - if self.training: # training mode - items_logits = torch.einsum( - 'bd,nd->bn', - last_embeddings, - self._item_embeddings.weight, - ) # (batch_size, num_items) - training_output = { - 'logits': items_logits, - 'sequence_representation': last_embeddings, - } - - # TODO remove this check - labels = inputs[ - '{}.ids'.format(self._labels_prefix) - ] # (batch_size) - assert torch.allclose( - self._item_embeddings(labels), - self._item_embeddings.weight[labels], - ) - - # Unsupervised Augmentation - embeddings_, mask_ = self._apply_sequential_encoder( - all_sample_events, - all_sample_lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - last_embeddings_ = self._get_last_embedding( - embeddings_, - mask_, - ) # (batch_size, embedding_dim) - training_output['similar_sequence_representation'] = ( - last_embeddings_ - ) - assert not torch.allclose( - last_embeddings, - last_embeddings_, - ), 'Embedding must be different because of dropout' - - # Semantic Similarity - all_sample_augmented_events = inputs[ - '{}.ids'.format(self._augmented_sequence_prefix) - ] # (all_batch_events) - all_sample_augmented_lengths = inputs[ - '{}.length'.format(self._augmented_sequence_prefix) - ] # (batch_size) - - augmented_embeddings, augmented_mask = ( - self._apply_sequential_encoder( - all_sample_augmented_events, - all_sample_augmented_lengths, - ) - ) # (batch_size, aug_seq_len, embedding_dim), (batch_size, aug_seq_len) - last_augmented_embeddings = self._get_last_embedding( - augmented_embeddings, - augmented_mask, - ) # (batch_size, embedding_dim) - training_output['augmented_sequence_representation'] = ( - last_augmented_embeddings - ) - - return training_output - else: # eval mode - candidate_embeddings = ( - self._item_embeddings.weight - ) # (num_items, embedding_dim) - candidate_scores = torch.einsum( - 'bd,nd->bn', - last_embeddings, - candidate_embeddings, - ) # (batch_size, num_items) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - _, 
indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20) - - return indices diff --git a/src/irec/models/graph_seq_rec.py b/src/irec/models/graph_seq_rec.py deleted file mode 100644 index 95464e09..00000000 --- a/src/irec/models/graph_seq_rec.py +++ /dev/null @@ -1,313 +0,0 @@ -from .base import SequentialTorchModel - -from irec.utils import create_masked_tensor, DEVICE - -import torch -import torch.nn as nn - - -class GraphSeqRecModel(SequentialTorchModel, config_name='graph_seq_rec'): - def __init__( - self, - sequence_prefix, - positive_prefix, - negative_prefix, - candidate_prefix, - common_graph, - user_graph, - item_graph, - num_hops, - graph_dropout, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - use_ce=False, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02, - ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - is_causal=True, - ) - self._sequence_prefix = sequence_prefix - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - self._candidate_prefix = candidate_prefix - - self._use_ce = use_ce - - self._common_graph = common_graph - self._user_graph = user_graph - self._item_graph = item_graph - self._num_hops = num_hops - self._graph_dropout = graph_dropout - - self._output_projection = nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - ) - - self._bias = nn.Parameter( - data=torch.zeros(num_items + 2), - requires_grad=True, - ) - - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - negative_prefix=config['negative_prefix'], - candidate_prefix=config['candidate_prefix'], - common_graph=kwargs['graph'], - user_graph=kwargs['user_graph'], - item_graph=kwargs['item_graph'], - num_hops=config['num_hops'], - graph_dropout=config['graph_dropout'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get( - 'num_heads', - int(config['embedding_dim'] // 64), - ), - num_layers=config['num_layers'], - dim_feedforward=config.get( - 'dim_feedforward', - 4 * config['embedding_dim'], - ), - dropout=config.get('dropout', 0.0), - use_ce=config.get('use_ce', False), - initializer_range=config.get('initializer_range', 0.02), - ) - - def _apply_graph_encoder(self, embeddings, graph): - if self.training: # training_mode - size = graph.size() - index = graph.indices().t() - values = graph.values() - dropout_mask = torch.rand(len(values)) + self._graph_dropout - dropout_mask = dropout_mask.int().bool() - index = index[~dropout_mask] - values = values[~dropout_mask] / (1.0 - self._graph_dropout) - graph_dropped = torch.sparse.FloatTensor(index.t(), values, size) - else: # eval mode - graph_dropped = graph - - for _ in range(self._num_hops): - embeddings = torch.sparse.mm(graph_dropped, embeddings) - - return embeddings - - def forward(self, inputs): - all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) - ] # (all_batch_events) - lengths = inputs[ - '{}.length'.format(self._sequence_prefix) - ] # (batch_size) - - common_graph_embeddings 
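The `_apply_graph_encoder` above drops a random fraction of edges from the sparse adjacency during training and rescales the surviving values by 1 / (1 - p), the usual inverted-dropout correction, before running `num_hops` rounds of sparse propagation. The same idea as a standalone sketch, using the current `torch.sparse_coo_tensor` API instead of the legacy `torch.sparse.FloatTensor` constructor used here:

import torch

def drop_edges(graph, p):
    # graph: sparse COO adjacency; each edge survives with probability 1 - p.
    graph = graph.coalesce()
    index, values = graph.indices(), graph.values()
    keep = torch.rand(values.shape[0], device=values.device) >= p
    return torch.sparse_coo_tensor(
        index[:, keep],
        values[keep] / (1.0 - p),  # rescale so expected message sums match
        graph.size(),
    )

def propagate(embeddings, graph, num_hops):
    for _ in range(num_hops):
        embeddings = torch.sparse.mm(graph, embeddings)
    return embeddings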
= self._apply_graph_encoder( - embeddings=self._item_embeddings.weight, - graph=self._item_graph, - ) # (num_items + 2, embedding_dim) - - embeddings = common_graph_embeddings[ - all_sample_events - ] # (all_batch_events, embedding_dim) - - embeddings, mask = create_masked_tensor( - data=embeddings, - lengths=lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - batch_size = mask.shape[0] - seq_len = mask.shape[1] - - positions = ( - torch.arange( - start=seq_len - 1, - end=-1, - step=-1, - device=mask.device, - )[None] - .tile([batch_size, 1]) - .long() - ) # (batch_size, seq_len) - positions_mask = ( - positions < lengths[:, None] - ) # (batch_size, max_seq_len) - - positions = positions[positions_mask] # (all_batch_events) - position_embeddings = self._position_embeddings( - positions, - ) # (all_batch_events, embedding_dim) - position_embeddings, _ = create_masked_tensor( - data=position_embeddings, - lengths=lengths, - ) # (batch_size, seq_len, embedding_dim) - assert torch.allclose(position_embeddings[~mask], embeddings[~mask]) - - embeddings = ( - embeddings + position_embeddings - ) # (batch_size, seq_len, embedding_dim) - - embeddings = self._layernorm( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = self._dropout( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - - embeddings[~mask] = 0 - - if self._is_causal: - causal_mask = ( - torch.tril( - torch.tile( - mask.unsqueeze(1), - dims=[self._num_heads, seq_len, 1], - ), - ) - .bool() - .to(DEVICE) - ) # (seq_len, seq_len) - embeddings = self._encoder( - src=embeddings, - mask=~causal_mask, - ) # (batch_size, seq_len, embedding_dim) - else: - embeddings = self._encoder( - src=embeddings, - src_key_padding_mask=~mask, - ) # (batch_size, seq_len, embedding_dim) - - if self._use_ce: - embeddings = self._output_projection( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = torch.nn.functional.gelu( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = torch.einsum( - 'bsd,nd->bsn', - embeddings, - self._item_embeddings.weight, - ) # (batch_size, seq_len, num_items) - embeddings += self._bias[ - None, - None, - :, - ] # (batch_size, seq_len, num_items) - else: - last_embeddings = self._get_last_embedding( - embeddings, - mask, - ) # (batch_size, embedding_dim) - - if self.training: # training mode - if self._use_ce: - return {'logits': embeddings[mask]} - else: - all_positive_sample_events = inputs[ - '{}.ids'.format(self._positive_prefix) - ] # (all_batch_events) - all_negative_sample_events = inputs[ - '{}.ids'.format(self._negative_prefix) - ] # (all_batch_events) - - all_sample_embeddings = embeddings[ - mask - ] # (all_batch_events, embedding_dim) - all_positive_sample_embeddings = self._item_embeddings( - all_positive_sample_events, - ) # (all_batch_events, embedding_dim) - all_negative_sample_embeddings = self._item_embeddings( - all_negative_sample_events, - ) # (all_batch_events, embedding_dim) - - return { - 'current_embeddings': all_sample_embeddings, - 'positive_embeddings': all_positive_sample_embeddings, - 'negative_embeddings': all_negative_sample_embeddings, - } - else: # eval mode - if self._use_ce: - last_embeddings = self._get_last_embedding( - embeddings, - mask, - ) # (batch_size, num_items) - - if '{}.ids'.format(self._candidate_prefix) in inputs: - candidate_events = inputs[ - '{}.ids'.format(self._candidate_prefix) - ] # (all_batch_candidates) - candidate_lengths = inputs[ - '{}.length'.format(self._candidate_prefix) - 
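The causal branch above encodes both causality and padding in one boolean tensor: the padding mask is tiled over heads and positions, `torch.tril` zeroes the upper triangle, and the complement is handed to the encoder, where True means "blocked". The construction in isolation, assuming `mask` is (batch_size, seq_len) with True marking real tokens:

import torch

def build_causal_mask(mask, num_heads):
    seq_len = mask.shape[1]
    # (batch_size * num_heads, seq_len, seq_len): position i may attend to
    # position j only if j <= i (tril) and token j is not padding (mask).
    allowed = torch.tril(mask.unsqueeze(1).tile(num_heads, seq_len, 1))
    return ~allowed  # nn.TransformerEncoder convention: True = masked out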
] # (batch_size) - - candidate_ids = torch.reshape( - candidate_events, - (candidate_lengths.shape[0], candidate_lengths[0]), - ) # (batch_size, num_candidates) - candidate_scores = last_embeddings.gather( - dim=1, - index=candidate_ids, - ) # (batch_size, num_candidates) - else: - candidate_scores = ( - last_embeddings # (batch_size, num_items + 2) - ) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - else: - if '{}.ids'.format(self._candidate_prefix) in inputs: - candidate_events = inputs[ - '{}.ids'.format(self._candidate_prefix) - ] # (all_batch_candidates) - candidate_lengths = inputs[ - '{}.length'.format(self._candidate_prefix) - ] # (batch_size) - - candidate_embeddings = self._item_embeddings( - candidate_events, - ) # (all_batch_candidates, embedding_dim) - - candidate_embeddings, _ = create_masked_tensor( - data=candidate_embeddings, - lengths=candidate_lengths, - ) # (batch_size, num_candidates, embedding_dim) - - candidate_scores = torch.einsum( - 'bd,bnd->bn', - last_embeddings, - candidate_embeddings, - ) # (batch_size, num_candidates) - else: - candidate_embeddings = ( - self._item_embeddings.weight - ) # (num_items, embedding_dim) - candidate_scores = torch.einsum( - 'bd,nd->bn', - last_embeddings, - candidate_embeddings, - ) # (batch_size, num_items) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - return candidate_scores diff --git a/src/irec/models/gru4rec.py b/src/irec/models/gru4rec.py deleted file mode 100644 index 392fb069..00000000 --- a/src/irec/models/gru4rec.py +++ /dev/null @@ -1,264 +0,0 @@ -from .base import TorchModel - -from irec.utils import create_masked_tensor, get_activation_function - -import torch -from torch import nn - - -class GRUModel(TorchModel): - def __init__( - self, - num_items, - max_sequence_length, - embedding_dim, - num_layers, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-5, - ): - super().__init__() - self._num_items = num_items - self._embedding_dim = embedding_dim - self._num_layers = num_layers - - self._item_embeddings = nn.Embedding( - num_embeddings=num_items - + 2, # add zero embedding + mask embedding - embedding_dim=embedding_dim, - ) - self._position_embeddings = nn.Embedding( - num_embeddings=max_sequence_length - + 1, # in order to include `max_sequence_length` value - embedding_dim=embedding_dim, - ) - - self._layernorm = nn.LayerNorm(embedding_dim, eps=layer_norm_eps) - self._dropout = nn.Dropout(dropout) - - self._encoder = nn.GRU( - input_size=embedding_dim, - hidden_size=embedding_dim, - num_layers=num_layers, - batch_first=True, - dropout=dropout, - bidirectional=False, - ) - - self._hidden_to_output_projection = nn.Linear(embedding_dim, num_items) - self._activation = get_activation_function(activation) - - def _apply_sequential_encoder(self, events, lengths): - embeddings = self._item_embeddings( - events, - ) # (all_batch_events, embedding_dim) - - embeddings, mask = create_masked_tensor( - data=embeddings, - lengths=lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - batch_size = mask.shape[0] - seq_len = mask.shape[1] - - positions = ( - torch.arange( - start=seq_len - 1, - end=-1, - step=-1, - device=mask.device, - )[None] - .tile([batch_size, 1]) - .long() - ) # (batch_size, seq_len) - positions_mask = ( - positions < lengths[:, None] - ) # (batch_size, max_seq_len) - - positions = positions[positions_mask] # (all_batch_events) - position_embeddings = self._position_embeddings( 
- positions, - ) # (all_batch_events, embedding_dim) - position_embeddings, _ = create_masked_tensor( - data=position_embeddings, - lengths=lengths, - ) # (batch_size, seq_len, embedding_dim) - assert torch.allclose(position_embeddings[~mask], embeddings[~mask]) - - embeddings = ( - embeddings + position_embeddings - ) # (batch_size, seq_len, embedding_dim) - - embeddings = self._layernorm( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = self._dropout( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - - embeddings[~mask] = 0 - - packed_embeddings = torch.nn.utils.rnn.pack_padded_sequence( - input=embeddings, - lengths=lengths.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - hidden = torch.zeros( - self._num_layers, - batch_size, - self._embedding_dim, - dtype=embeddings.dtype, - device=embeddings.device, - requires_grad=True, - ) # (num_layers, batch_size, embedding_dim) - out, hidden = self._encoder(packed_embeddings, hidden) - embeddings, embedding_lengths = torch.nn.utils.rnn.pad_packed_sequence( - out, - batch_first=True, - ) # (batch_size, seq_len, embedding_dim) (batch_size, seq_len) - embedding_lengths = embedding_lengths.to(lengths.device) - - assert torch.allclose(lengths, embedding_lengths) - - return embeddings, mask - - -class GRU4RecModel(GRUModel, config_name='gru4rec'): - def __init__( - self, - sequence_prefix, - positive_prefix, - negative_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_layers, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-5, - initializer_range=0.02, - ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_layers=num_layers, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - ) - - self._sequence_prefix = sequence_prefix - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - negative_prefix=config['negative_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_layers=config['num_layers'], - dropout=config.get('dropout', 0.0), - activation=config.get('activation', 'tanh'), - layer_norm_eps=config.get('layer_norm_eps', 1e-5), - initializer_range=config.get('initializer_range', 0.02), - ) - - def forward(self, inputs): - all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) - ] # (all_batch_events) - all_sample_lengths = inputs[ - '{}.length'.format(self._sequence_prefix) - ] # (batch_size) - - embeddings, mask = self._apply_sequential_encoder( - events=all_sample_events, - lengths=all_sample_lengths, - ) # (batch_size, seq_len, embedding_dim) (batch_size, seq_len) - - if self.training: # training mode - all_positive_sample_events = inputs[ - '{}.ids'.format(self._positive_prefix) - ] # (all_batch_events) - all_positive_sample_embeddings = self._item_embeddings( - all_positive_sample_events, - ) # (all_batch_events, embedding_dim) - - all_sample_embeddings = embeddings[ - mask - ] # (all_batch_events, embedding_dim) - - sample_end_idx = torch.cumsum( - all_sample_lengths, - dim=0, - ) # (batch_size) - sample_begin_idx = ( - sample_end_idx - all_sample_lengths - ) # (batch_size) - - sample_end_idx = sample_end_idx[:, None] # (batch_size, 1) - 
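The GRU encoder above packs the padded batch so the RNN never consumes padding positions, then unpacks and asserts that the returned lengths match the inputs. The round trip in isolation (toy sizes, illustrative):

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

gru = nn.GRU(input_size=8, hidden_size=8, num_layers=1, batch_first=True)
embeddings = torch.randn(3, 5, 8)     # (batch, seq_len, dim), right-padded
lengths = torch.tensor([5, 3, 2])     # true length of each sequence

packed = pack_padded_sequence(
    embeddings, lengths, batch_first=True, enforce_sorted=False,
)
out, hidden = gru(packed)
unpacked, out_lengths = pad_packed_sequence(out, batch_first=True)
assert torch.equal(lengths, out_lengths)  # lengths survive the round trip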
sample_begin_idx = sample_begin_idx[:, None] # (batch_size, 1) - - negative_indices = torch.tile( - torch.arange( - start=0, - end=all_positive_sample_events.shape[0], - device=all_sample_lengths.device, - ).long()[None], - dims=[all_sample_lengths.shape[0], 1], - ) # (batch_size, all_batch_events) - - negative_mask = (negative_indices >= sample_begin_idx) & ( - negative_indices < sample_end_idx - ) - negative_mask = torch.repeat_interleave( - negative_mask, - all_sample_lengths, - dim=0, - ) - - negative_scores = torch.einsum( - 'ad,bd->ab', - all_sample_embeddings, - self._item_embeddings(all_sample_events), - ) # (all_batch_events, all_batch_events) - - positive_scores = torch.einsum( - 'ad,ad->a', - all_sample_embeddings, - all_positive_sample_embeddings, - ) # (all_batch_events) - - return { - 'positive_scores': positive_scores[..., None], - 'negative_scores': negative_scores, - } - else: # eval mode - last_embeddings = self._get_last_embedding( - embeddings, - mask, - ) # (batch_size, embedding_dim) - candidate_scores = torch.einsum( - 'bd,nd->bn', - last_embeddings, - self._item_embeddings.weight, - ) # (batch_size, num_items) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - _, indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20) - - return indices diff --git a/src/irec/models/gtorec.py b/src/irec/models/gtorec.py deleted file mode 100644 index d6945d63..00000000 --- a/src/irec/models/gtorec.py +++ /dev/null @@ -1,571 +0,0 @@ -from .base import SequentialTorchModel - -from irec.utils import create_masked_tensor, get_activation_function - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class GTOModel(SequentialTorchModel, config_name='gtorec'): - def __init__( - self, - # sequential params - sequence_prefix, # =item_prefix - positive_prefix, - negative_prefix, - candidate_prefix, - source_domain, - num_users, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - # graph params - user_prefix, - graph, - graph_embedding_dim, - graph_num_layers, - # params with default values - dropout=0.0, - graph_dropout=0.0, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02, - norm_first=True, - ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - is_causal=True, - ) - # sequential part - self._sequence_prefix = sequence_prefix - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - self._candidate_prefix = candidate_prefix - self._source_domain = source_domain - - self._output_projection = nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - ) - self._bias = nn.Parameter( - data=torch.zeros(num_items + 2), - requires_grad=True, - ) - - # graph part - self._user_prefix = user_prefix - self._num_users = num_users - self._graph = graph - self._graph_embedding_dim = graph_embedding_dim - self._graph_num_layers = graph_num_layers - self._graph_dropout = graph_dropout - - self._graph_user_embeddings = nn.Embedding( - num_embeddings=num_users + 2, - embedding_dim=self._graph_embedding_dim, - ) - self._graph_item_embeddings = nn.Embedding( - num_embeddings=num_items + 2, - embedding_dim=self._graph_embedding_dim, - ) - - # cross_attention part - 
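GRU4Rec's training branch above scores every sequence position against every event in the batch ('ad,bd->ab') and builds `negative_mask` so that items from the same sequence are excluded from a row's negatives. The masking logic in isolation, assuming events are stored contiguously per sequence:

import torch

def same_sequence_mask(lengths):
    # lengths: (batch_size,); sequence b's events occupy one contiguous slice.
    end = torch.cumsum(lengths, dim=0)
    begin = end - lengths
    cols = torch.arange(int(end[-1]))[None]          # (1, all_batch_events)
    per_seq = (cols >= begin[:, None]) & (cols < end[:, None])
    # One row per event; True marks columns from the event's own sequence.
    return torch.repeat_interleave(per_seq, lengths, dim=0)

mask = same_sequence_mask(torch.tensor([3, 2]))
# mask is (5, 5): rows 0..2 flag columns 0..2, rows 3..4 flag columns 3..4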
self._mha = nn.MultiheadAttention( - embed_dim=embedding_dim, - num_heads=num_heads, - dropout=dropout, - bias=True, - add_bias_kv=False, - add_zero_attn=False, - batch_first=True, - ) - - self.linear1 = nn.Linear(embedding_dim, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, embedding_dim) - self.activation = get_activation_function(activation) - - self.norm_first = norm_first - self.norm1 = nn.LayerNorm(embedding_dim, eps=layer_norm_eps) - self.norm2 = nn.LayerNorm(embedding_dim, eps=layer_norm_eps) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self._mha_output_projection = nn.Linear( - in_features=2 * embedding_dim, - out_features=embedding_dim, - ) - - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - # sequential part - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - negative_prefix=config['negative_prefix'], - candidate_prefix=config['candidate_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get( - 'num_heads', - int(config['embedding_dim'] // 64), - ), - num_layers=config['num_layers'], - dim_feedforward=config.get( - 'dim_feedforward', - 4 * config['embedding_dim'], - ), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02), - norm_first=config.get('norm_first', True), - # graph part - user_prefix=config['user_prefix'], - num_users=kwargs['num_users'], - graph_embedding_dim=config['graph_embedding_dim'], - graph_num_layers=config['graph_num_layers'], - graph_dropout=config.get('graph_dropout', 0.0), - ) - - def _apply_graph_encoder(self): - ego_embeddings = torch.cat( - ( - self._graph_user_embeddings.weight, - self._graph_item_embeddings.weight, - ), - dim=0, - ) - all_embeddings = [ego_embeddings] - - if self._graph_dropout > 0: # drop some edges - if self.training: # training_mode - size = self._graph.size() - index = self._graph.indices().t() - values = self._graph.values() - random_index = torch.rand(len(values)) + ( - 1 - self._graph_dropout - ) - random_index = random_index.int().bool() - index = index[random_index] - values = values[random_index] / (1 - self._graph_dropout) - graph_dropped = torch.sparse.FloatTensor( - index.t(), - values, - size, - ) - else: # eval mode - graph_dropped = self._graph - else: - graph_dropped = self._graph - - for i in range(self._graph_num_layers): - ego_embeddings = torch.sparse.mm(graph_dropped, ego_embeddings) - norm_embeddings = F.normalize(ego_embeddings, p=2, dim=1) - all_embeddings += [norm_embeddings] - - all_embeddings = torch.cat(all_embeddings, dim=-1) - user_final_embeddings, item_final_embeddings = torch.split( - all_embeddings, - [self._num_users + 2, self._num_items + 2], - ) - - return user_final_embeddings, item_final_embeddings - - def _get_graph_embeddings( - self, - inputs, - prefix, - ego_embeddings, - final_embeddings, - ): - ids = inputs['{}.ids'.format(prefix)] # (batch_size) - lengths = inputs['{}.length'.format(prefix)] # (batch_size) - - final_embeddings = final_embeddings[ids] # (batch_size, emb_dim) - ego_embeddings = ego_embeddings(ids) # (batch_size, emb_dim) - - padded_embeddings, mask = create_masked_tensor( - final_embeddings, - lengths, - ) - padded_ego_embeddings, ego_mask = create_masked_tensor( - ego_embeddings, - lengths, - ) - - assert torch.all(mask == ego_mask) 
- - return padded_embeddings, padded_ego_embeddings, mask - - def _ca_block(self, q, k, v, attn_mask=None, key_padding_mask=None): - x = self._mha( - q, - k, - v, - attn_mask=attn_mask, - key_padding_mask=key_padding_mask, - need_weights=False, - )[0] # (batch_size, seq_len, embedding_dim) - return self.dropout1(x) # (batch_size, seq_len, embedding_dim) - - def _ff_block(self, x): - x = self.linear2(self.dropout(self.activation(self.linear1(x)))) - return self.dropout2(x) - - def forward(self, inputs): - # target domain item sequence - all_sample_events_target = inputs[ - '{}.ids'.format(self._sequence_prefix) - ] # (all_batch_events) - all_sample_lengths_target = inputs[ - '{}.length'.format(self._sequence_prefix) - ] # (batch_size) - # source domain item sequence - all_sample_events_source = inputs[ - '{}.{}.ids'.format(self._sequence_prefix, self._source_domain) - ] # (all_batch_events) - all_sample_lengths_source = inputs[ - '{}.{}.length'.format(self._sequence_prefix, self._source_domain) - ] # (batch_size) - - # sequential model encoder and target domain items embeddings from sequential model - seq_embeddings_target, seq_mask_target = ( - self._apply_sequential_encoder( - all_sample_events_target, - all_sample_lengths_target, - ) - ) # (batch_size, target_seq_len, embedding_dim), (batch_size, target_seq_len) - - # target domain items encoder for graph model - all_final_user_embeddings_target, all_final_item_embeddings_target = ( - self._apply_graph_encoder( - all_sample_events_target, - all_sample_lengths_target, - ) - ) # (num_users + 2, embedding_dim), (num_items + 2, embedding_dim) - # source domain items encoder for graph model - all_final_user_embeddings_source, all_final_item_embeddings_source = ( - self._apply_graph_encoder( - all_sample_events_source, - all_sample_lengths_source, - ) - ) # (num_users + 2, embedding_dim), (num_items + 2, embedding_dim) - - # target domain items embeddings from graph model - ( - graph_embeddings_target, - graph_item_ego_embeddings_target, - graph_item_mask_target, - ) = self._get_graph_embeddings( - inputs, - self._sequence_prefix, - self._graph_item_embeddings, - all_final_item_embeddings_target, - ) - graph_item_embeddings_target = graph_embeddings_target[ - graph_item_mask_target - ] # (batch_size, target_seq_len, embedding_dim) - # source domain items embeddings from graph model - ( - graph_embeddings_source, - graph_item_ego_embeddings_source, - graph_item_mask_source, - ) = self._get_graph_embeddings( - inputs, - self._sequence_prefix, - self._graph_item_embeddings, - all_final_item_embeddings_source, - ) - graph_item_embeddings_source = graph_embeddings_source[ - graph_item_mask_source - ] # (batch_size, source_seq_len, embedding_dim) - - # embeddings + graph_embeddings_target -> cross-attention - # query = embeddings - # keys = graph_embeddings_target - # values = graph_embeddings_target - if self.norm_first: - graph_embeddings_target = graph_embeddings_target + self.norm1( - self._ca_block( - q=seq_embeddings_target, - k=graph_embeddings_target, - v=graph_embeddings_target, - attn_mask=None, - key_padding_mask=~graph_item_mask_target, - ), - ) # (batch_size, target_seq_len, embedding_dim) - graph_embeddings_target = graph_embeddings_target + self.norm2( - self._ff_block(graph_embeddings_target), - ) - else: - graph_embeddings_target = self.norm1( - graph_embeddings_target - + self._ca_block( - q=seq_embeddings_target, - k=graph_embeddings_target, - v=graph_embeddings_target, - attn_mask=None, - key_padding_mask=~graph_item_mask_target, 
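`_ca_block` and `_ff_block` above form a Transformer-style residual unit in which the sequence representation queries the graph embeddings. Note that the `norm_first` branch here normalizes the sublayer output (x + norm(ca(x))), whereas `nn.TransformerEncoderLayer`'s pre-norm normalizes the input (x + ca(norm(x))). For comparison, a sketch of the conventional pre-norm cross-attention block (class name and dimensions are illustrative):

import torch.nn as nn

class CrossAttentionBlock(nn.Module):
    def __init__(self, dim, num_heads, dropout=0.0):
        super().__init__()
        self.mha = nn.MultiheadAttention(dim, num_heads,
                                         dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(),
                                nn.Linear(4 * dim, dim))

    def forward(self, query, memory, memory_padding_mask=None):
        # Pre-norm: normalize the input to each sublayer, then add residually.
        attn, _ = self.mha(self.norm1(query), memory, memory,
                           key_padding_mask=memory_padding_mask)
        x = query + attn
        return x + self.ff(self.norm2(x))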
- ), - ) # (batch_size, target_seq_len, embedding_dim) - graph_embeddings_target = self.norm2( - graph_embeddings_target - + self._ff_block(graph_embeddings_target), - ) - # target-target cross-attention result - mha_embeddings_target = torch.cat( - [seq_embeddings_target, graph_embeddings_target], - dim=-1, - ) - mha_embeddings_target = self._mha_output_projection( - mha_embeddings_target, - ) # (batch_size, target_seq_len, embedding_dim) - - # embeddings + graph_embeddings_source -> cross-attention - # query = embeddings - # keys = graph_embeddings_source - # values = graph_embeddings_source - if self.norm_first: - graph_embeddings_source = graph_embeddings_source + self.norm1( - self._ca_block( - q=seq_embeddings_target, - k=graph_embeddings_source, - v=graph_embeddings_source, - attn_mask=None, - key_padding_mask=~graph_item_mask_source, - ), - ) # (batch_size, seq_len, embedding_dim) - graph_embeddings_source = graph_embeddings_source + self.norm2( - self._ff_block(graph_embeddings_source), - ) - else: - graph_embeddings_source = self.norm1( - graph_embeddings_source - + self._ca_block( - q=seq_embeddings_target, - k=graph_embeddings_source, - v=graph_embeddings_source, - attn_mask=None, - key_padding_mask=~graph_item_mask_source, - ), - ) # (batch_size, seq_len, embedding_dim) - graph_embeddings_source = self.norm2( - graph_embeddings_source - + self._ff_block(graph_embeddings_source), - ) - # source-target cross-attention result - mha_embeddings_source = torch.cat( - [seq_embeddings_target, graph_embeddings_source], - dim=-1, - ) - mha_embeddings_source = self._mha_output_projection( - mha_embeddings_source, - ) # (batch_size, seq_len, embedding_dim) - - if self.training: # training mode - # sequential part - all_positive_sample_events = inputs[ - '{}.ids'.format(self._positive_prefix) - ] # (all_batch_events) - all_negative_sample_events = inputs[ - '{}.ids'.format(self._negative_prefix) - ] # (all_batch_events) - - all_sample_embeddings = seq_embeddings_target[ - seq_mask_target - ] # (all_batch_events, embedding_dim) - all_positive_sample_embeddings = self._item_embeddings( - all_positive_sample_events, - ) # (all_batch_events, embedding_dim) - all_negative_sample_embeddings = self._item_embeddings( - all_negative_sample_events, - ) # (all_batch_events, embedding_dim) - - # graph part, target domain item embeddings - graph_positive_embeddings_target, _, graph_positive_mask_target = ( - self._get_graph_embeddings( - inputs, - self._positive_prefix, - self._graph_item_embeddings, - all_final_item_embeddings_target, - ) - ) - graph_negative_embeddings_target, _, graph_negative_mask_target = ( - self._get_graph_embeddings( - inputs, - self._negative_prefix, - self._graph_item_embeddings, - all_final_item_embeddings_target, - ) - ) - # b - batch_size, s - seq_len, d - embedding_dim - graph_positive_scores_target = torch.einsum( - 'bd,bsd->bs', - graph_item_embeddings_target, - graph_positive_embeddings_target, - ) # (batch_size, target_seq_len) - graph_negative_scores_target = torch.einsum( - 'bd,bsd->bs', - graph_item_embeddings_target, - graph_negative_embeddings_target, - ) # (batch_size, target_seq_len) - graph_positive_scores_target = graph_positive_scores_target[ - graph_positive_mask_target - ] # (all_batch_events) - graph_negative_scores_target = graph_negative_scores_target[ - graph_negative_mask_target - ] # (all_batch_events) - - # graph part, source domain item embeddings - graph_positive_embeddings_source, _, graph_positive_mask_source = ( - self._get_graph_embeddings( - 
inputs, - self._positive_prefix, - self._graph_item_embeddings, - all_final_item_embeddings_source, - ) - ) - graph_negative_embeddings_source, _, graph_negative_mask_source = ( - self._get_graph_embeddings( - inputs, - self._negative_prefix, - self._graph_item_embeddings, - all_final_item_embeddings_source, - ) - ) - # b - batch_size, s - seq_len, d - embedding_dim - graph_positive_scores_source = torch.einsum( - 'bd,bsd->bs', - graph_item_embeddings_source, - graph_positive_embeddings_source, - ) # (batch_size, source_seq_len) - graph_negative_scores_source = torch.einsum( - 'bd,bsd->bs', - graph_item_embeddings_source, - graph_negative_embeddings_source, - ) # (batch_size, source_seq_len) - graph_positive_scores_source = graph_positive_scores_source[ - graph_positive_mask_source - ] # (all_batch_events) - graph_negative_scores_source = graph_negative_scores_source[ - graph_negative_mask_source - ] # (all_batch_events) - - # mha part - mha_all_sample_embeddings_target = mha_embeddings_target[ - seq_mask_target - ] # (all_batch_events, embedding_dim) - mha_all_sample_embeddings_source = mha_embeddings_source[ - seq_mask_target - ] # (all_batch_events, embedding_dim) - - return { - # sequential part - # target domain item embeddings - 'current_embeddings': all_sample_embeddings, - 'positive_embeddings': all_positive_sample_embeddings, - 'negative_embeddings': all_negative_sample_embeddings, - # graph part - # target domain item embeddings - 'graph_positive_embeddings_target': graph_positive_embeddings_target[ - graph_positive_mask_target - ], - 'graph_negative_embeddings_target': graph_negative_embeddings_target[ - graph_negative_mask_target - ], - 'graph_positive_scores_target': graph_positive_scores_target, - 'graph_negative_scores_target': graph_negative_scores_target, - 'graph_item_embeddings_target': graph_item_embeddings_target, - # source domain item embeddings - 'graph_positive_embeddings_source': graph_positive_embeddings_source[ - graph_positive_mask_source - ], - 'graph_negative_embeddings_source': graph_negative_embeddings_source[ - graph_negative_mask_source - ], - 'graph_positive_scores_source': graph_positive_scores_source, - 'graph_negative_scores_source': graph_negative_scores_source, - 'graph_item_embeddings_source': graph_item_embeddings_source, - # mha part - # target domain item embeddings - 'mha_embeddings_target': mha_all_sample_embeddings_target, - 'mha_positive_embeddings_target': all_positive_sample_embeddings, - 'mha_negative_embeddings_target': all_negative_sample_embeddings, - # source domain item embeddings - 'mha_embeddings_source': mha_all_sample_embeddings_source, - 'mha_positive_embeddings_source': all_positive_sample_embeddings, - 'mha_negative_embeddings_source': all_negative_sample_embeddings, - } - else: # eval mode - seq_last_embeddings_target = self._get_last_embedding( - seq_embeddings_target, - seq_mask_target, - ) # (batch_size, embedding_dim) - mha_last_embeddings_target = self._get_last_embedding( - mha_embeddings_target, - seq_mask_target, - ) # (batch_size, embedding_dim) - mha_last_embeddings_source = self._get_last_embedding( - mha_embeddings_source, - seq_mask_target, - ) # (batch_size, embedding_dim) - - aggregated_last_embeddings = torch.maximum( - seq_last_embeddings_target, - torch.maximum( - mha_last_embeddings_target, - mha_last_embeddings_source, - ), - ) # (batch_size, embedding_dim) - - # b - batch_size, n - num_candidates, d - embedding_dim - candidate_scores = torch.einsum( - 'bd,nd->bn', - aggregated_last_embeddings, - 
self._item_embeddings.weight, - ) # (batch_size, num_items + 2) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - if '{}.ids'.format(self._candidate_prefix) in inputs: - candidate_events = inputs[ - '{}.ids'.format(self._candidate_prefix) - ] # (all_batch_candidates) - candidate_lengths = inputs[ - '{}.length'.format(self._candidate_prefix) - ] # (batch_size) - - batch_size = candidate_lengths.shape[0] - num_candidates = candidate_lengths[0] - - candidate_scores = torch.gather( - input=candidate_scores, - dim=1, - index=torch.reshape( - candidate_events, - [batch_size, num_candidates], - ), - ) # (batch_size, num_candidates) - - _, indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20), (batch_size, 20) - - return indices diff --git a/src/irec/models/lightgcn.py b/src/irec/models/lightgcn.py deleted file mode 100644 index fd22f7eb..00000000 --- a/src/irec/models/lightgcn.py +++ /dev/null @@ -1,227 +0,0 @@ -from .base import TorchModel - -from irec.utils import create_masked_tensor, DEVICE - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class LightGCNModel(TorchModel, config_name='light_gcn'): - def __init__( - self, - user_prefix, - positive_prefix, - graph, - num_users, - num_items, - embedding_dim, - num_layers, - dropout=0.0, - initializer_range=0.02, - ): - super().__init__() - self._user_prefix = user_prefix - self._positive_prefix = positive_prefix - self._graph = graph - self._num_users = num_users - self._num_items = num_items - self._embedding_dim = embedding_dim - self._num_layers = num_layers - self._dropout_rate = dropout - - self._user_embeddings = nn.Embedding( - num_embeddings=self._num_users + 2, - embedding_dim=self._embedding_dim, - ) - - self._item_embeddings = nn.Embedding( - num_embeddings=self._num_items + 2, - embedding_dim=self._embedding_dim, - ) - - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - user_prefix=config['user_prefix'], - positive_prefix=config['positive_prefix'], - graph=kwargs['graph'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - embedding_dim=config['embedding_dim'], - num_layers=config['num_layers'], - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02), - ) - - def _apply_graph_encoder(self): - ego_embeddings = torch.cat( - (self._user_embeddings.weight, self._item_embeddings.weight), - dim=0, - ) - all_embeddings = [ego_embeddings] - - if self._dropout_rate > 0: # drop some edges - if self.training: # training_mode - size = self._graph.size() - index = self._graph.indices().t() - values = self._graph.values() - random_index = torch.rand(len(values)) + ( - 1 - self._dropout_rate - ) - random_index = random_index.int().bool() - index = index[random_index] - values = values[random_index] / (1 - self._dropout_rate) - graph_dropped = torch.sparse.FloatTensor( - index.t(), - values, - size, - ) - else: # eval mode - graph_dropped = self._graph - else: - graph_dropped = self._graph - - for i in range(self._num_layers): - ego_embeddings = torch.sparse.mm(graph_dropped, ego_embeddings) - norm_embeddings = F.normalize(ego_embeddings, p=2, dim=1) - all_embeddings += [norm_embeddings] - - all_embeddings = torch.cat(all_embeddings, dim=-1) - user_final_embeddings, item_final_embeddings = torch.split( - all_embeddings, - [self._num_users + 2, self._num_items + 2], - ) - - return user_final_embeddings, 
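LightGCN's encoder above stacks the user and item tables, propagates them through the normalized adjacency for `num_layers` hops with no per-layer transforms or nonlinearities, and keeps every hop's output; this variant L2-normalizes each hop and concatenates the layers, where the LightGCN paper averages them. The propagation skeleton:

import torch
import torch.nn.functional as F

def lightgcn_propagate(user_emb, item_emb, adj, num_layers):
    # adj: sparse (U + I, U + I) normalized bipartite adjacency matrix.
    ego = torch.cat([user_emb, item_emb], dim=0)
    layers = [ego]
    for _ in range(num_layers):
        ego = torch.sparse.mm(adj, ego)         # one propagation hop
        layers.append(F.normalize(ego, p=2, dim=1))
    final = torch.cat(layers, dim=-1)           # concat per hop, not mean
    return torch.split(final, [user_emb.shape[0], item_emb.shape[0]])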
item_final_embeddings - - def _get_embeddings( - self, - inputs, - prefix, - ego_embeddings, - final_embeddings, - ): - ids = inputs['{}.ids'.format(prefix)] # (all_batch_events) - lengths = inputs['{}.length'.format(prefix)] # (batch_size) - - final_embeddings = final_embeddings[ - ids - ] # (all_batch_events, embedding_dim) - ego_embeddings = ego_embeddings( - ids, - ) # (all_batch_events, embedding_dim) - - padded_embeddings, mask = create_masked_tensor( - final_embeddings, - lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - padded_ego_embeddings, ego_mask = create_masked_tensor( - ego_embeddings, - lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - assert torch.all(mask == ego_mask) - - return padded_embeddings, padded_ego_embeddings, mask - - def forward(self, inputs): - all_final_user_embeddings, all_final_item_embeddings = ( - self._apply_graph_encoder() - ) # (num_users + 2, embedding_dim), (num_items + 2, embedding_dim) - - user_embeddings, user_ego_embeddings, user_mask = self._get_embeddings( - inputs, - self._user_prefix, - self._user_embeddings, - all_final_user_embeddings, - ) - user_embeddings = user_embeddings[ - user_mask - ] # (batch_size, embedding_dim) - - if self.training: # training mode - positive_item_ids = inputs[ - '{}.ids'.format(self._positive_prefix) - ] # (all_batch_events) - positive_item_lengths = inputs[ - '{}.length'.format(self._positive_prefix) - ] # (batch_size) - - batch_size = positive_item_lengths.shape[0] - max_sequence_length = positive_item_lengths.max().item() - - mask = ( - torch.arange(end=max_sequence_length, device=DEVICE)[ - None - ].tile([batch_size, 1]) - < positive_item_lengths[:, None] - ) # (batch_size, max_seq_len) - - positive_user_ids = ( - torch.arange(batch_size, device=DEVICE)[None] - .tile([max_sequence_length, 1]) - .T - ) # (batch_size, max_seq_len) - positive_user_ids = positive_user_ids[mask] # (all_batch_items) - user_embeddings = user_embeddings[ - positive_user_ids - ] # (all_batch_items, embedding_dim) - - all_scores = torch.einsum( - 'ad,nd->an', - user_embeddings, - all_final_item_embeddings, - ) # (all_batch_items, num_items + 2) - - negative_mask = torch.zeros( - self._num_items + 2, - dtype=torch.bool, - device=DEVICE, - ) # (num_items + 2) - negative_mask[positive_item_ids] = 1 - - positive_scores = torch.gather( - input=all_scores, - dim=1, - index=positive_item_ids[..., None], - ) # (all_batch_items, 1) - - all_scores = torch.scatter_add( - input=all_scores, - dim=1, - index=positive_item_ids[..., None], - src=torch.ones_like(positive_item_ids[..., None]).float(), - ) # (all_batch_items, num_items + 2) - - return { - 'positive_scores': positive_scores, - 'negative_scores': all_scores, - 'item_embeddings': torch.cat( - ( - self._user_embeddings.weight, - self._item_embeddings.weight, - ), - dim=0, - ), - } - else: # eval mode - candidate_scores = torch.einsum( - 'bd,nd->bn', - user_embeddings, - all_final_item_embeddings, - ) # (batch_size, num_items + 2) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - _, indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20) - - return indices diff --git a/src/irec/models/mclsr.py b/src/irec/models/mclsr.py deleted file mode 100644 index fc8c6ef8..00000000 --- a/src/irec/models/mclsr.py +++ /dev/null @@ -1,436 +0,0 @@ -from .base import TorchModel - -import torch -import torch.nn as nn - -from irec.utils import 
create_masked_tensor - - -class MCLSRModel(TorchModel, config_name='mclsr'): - def __init__( - self, - sequence_prefix, - user_prefix, - labels_prefix, - negatives_prefix, - candidate_prefix, - num_users, - num_items, - max_sequence_length, - embedding_dim, - num_graph_layers, - common_graph, - user_graph, - item_graph, - dropout=0.0, - layer_norm_eps=1e-5, - graph_dropout=0.0, - alpha=0.5, - initializer_range=0.02, - ): - super().__init__() - self._sequence_prefix = sequence_prefix - self._user_prefix = user_prefix - self._labels_prefix = labels_prefix - self._negatives_prefix = negatives_prefix - self._candidate_prefix = candidate_prefix - - self._num_users = num_users - self._num_items = num_items - - self._embedding_dim = embedding_dim - - self._num_graph_layers = num_graph_layers - self._graph_dropout = graph_dropout - - self._alpha = alpha - - self._graph = common_graph - self._user_graph = user_graph - self._item_graph = item_graph - - self._item_embeddings = nn.Embedding( - num_embeddings=num_items - + 2, # add zero embedding + mask embedding - embedding_dim=embedding_dim, - ) - self._position_embeddings = nn.Embedding( - num_embeddings=max_sequence_length - + 1, # in order to include `max_sequence_length` value - embedding_dim=embedding_dim, - ) - - self._user_embeddings = nn.Embedding( - num_embeddings=num_users - + 2, # add zero embedding + mask embedding - embedding_dim=embedding_dim, - ) - - self._layernorm = nn.LayerNorm(embedding_dim, eps=layer_norm_eps) - self._dropout = nn.Dropout(dropout) - - # Current interest learning - self._current_interest_learning_encoder = nn.Sequential( - nn.Linear( - in_features=embedding_dim, - out_features=4 * embedding_dim, - bias=False, - ), - nn.Tanh(), - nn.Linear( - in_features=4 * embedding_dim, - out_features=1, - bias=False, - ), - ) - - # General interest learning - self._general_interest_learning_encoder = nn.Sequential( - nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - bias=False, - ), - nn.Tanh(), - ) - - # Cross-view contrastive learning - self._sequential_projector = nn.Sequential( - nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - bias=True, - ), - nn.ELU(), - nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - bias=True, - ), - ) - self._graph_projector = nn.Sequential( - nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - bias=True, - ), - nn.ELU(), - nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - bias=True, - ), - ) - - self._user_projection = nn.Sequential( - nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - bias=True, - ), - nn.ELU(), - nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - bias=True, - ), - ) - - self._item_projection = nn.Sequential( - nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - bias=True, - ), - nn.ELU(), - nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - bias=True, - ), - ) - - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - user_prefix=config['user_prefix'], - labels_prefix=config['labels_prefix'], - negatives_prefix=config.get('negatives_prefix', 'negatives'), - candidate_prefix=config['candidate_prefix'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - 
num_graph_layers=config['num_graph_layers'], - common_graph=kwargs['graph'], - user_graph=kwargs['user_graph'], - item_graph=kwargs['item_graph'], - dropout=config.get('dropout', 0.0), - layer_norm_eps=config.get('layer_norm_eps', 1e-5), - graph_dropout=config.get('graph_dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02), - ) - - def _apply_graph_encoder(self, embeddings, graph, use_mean=False): - assert self.training # Here we use graph only in training_mode - - size = graph.size() - index = graph.indices().t() - values = graph.values() - dropout_mask = torch.rand(len(values)) + self._graph_dropout - dropout_mask = dropout_mask.int().bool() - index = index[~dropout_mask] - values = values[~dropout_mask] / (1.0 - self._graph_dropout) - graph_dropped = torch.sparse.FloatTensor(index.t(), values, size) - - all_embeddings = [embeddings] - for _ in range(self._num_graph_layers): - # import code; code.interact(local=locals()) - new_embeddings = torch.sparse.mm(graph_dropped, all_embeddings[-1]) - all_embeddings.append(new_embeddings) - - if use_mean: - all_embeddings = torch.stack(all_embeddings, dim=1) - return torch.mean(all_embeddings, dim=1) - else: - return all_embeddings[-1] - - def forward(self, inputs): - all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) - ] # (all_batch_events) - all_sample_lengths = inputs[ - '{}.length'.format(self._sequence_prefix) - ] # (batch_size) - user_ids = inputs['{}.ids'.format(self._user_prefix)] # (batch_size) - - embeddings = self._item_embeddings( - all_sample_events, - ) # (all_batch_events, embedding_dim) - embeddings, mask = create_masked_tensor( - data=embeddings, - lengths=all_sample_lengths, - ) # (batch_size, seq_len, embedding_dim) - - batch_size = mask.shape[0] - seq_len = mask.shape[1] - - # Current interest learning - # 1) get embeddings with positions - positions = ( - torch.arange( - start=seq_len - 1, - end=-1, - step=-1, - device=mask.device, - )[None] - .tile([batch_size, 1]) - .long() - ) # (batch_size, seq_len) - positions_mask = ( - positions < all_sample_lengths[:, None] - ) # (batch_size, max_seq_len) - - positions = positions[positions_mask] # (all_batch_events) - position_embeddings = self._position_embeddings( - positions, - ) # (all_batch_events, embedding_dim) - position_embeddings, _ = create_masked_tensor( - data=position_embeddings, - lengths=all_sample_lengths, - ) # (batch_size, seq_len, embedding_dim) - assert torch.allclose(position_embeddings[~mask], embeddings[~mask]) - - positioned_embeddings = ( - embeddings + position_embeddings - ) # (batch_size, seq_len, embedding_dim) - - positioned_embeddings = self._layernorm( - positioned_embeddings, - ) # (batch_size, seq_len, embedding_dim) - positioned_embeddings = self._dropout( - positioned_embeddings, - ) # (batch_size, seq_len, embedding_dim) - positioned_embeddings[~mask] = 0 - - # formula 2 - sequential_attention_matrix = self._current_interest_learning_encoder( - positioned_embeddings, # E_u,p - ).squeeze() # (batch_size, seq_len) - - sequential_attention_matrix[~mask] = -torch.inf - sequential_attention_matrix = torch.softmax( - sequential_attention_matrix, - dim=1, - ) # (batch_size, seq_len) - - # formula 3 - sequential_representation = torch.einsum( - 'bs,bsd->bd', - sequential_attention_matrix, # A^s - embeddings, - ) # (batch_size, embedding_dim) - - if self.training: - # general interest - # formula 4 - all_init_embeddings = torch.cat([self._user_embeddings.weight, - self._item_embeddings.weight], - dim=0) - 
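MCLSR's current-interest module above (formulas 2 and 3 in the comments) is attention pooling: a two-layer MLP scores each position, padding positions get -inf before the softmax, and the sequence representation is the attention-weighted sum of the item embeddings. The same module in isolation:

import torch
import torch.nn as nn

class AttentionPooling(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.score = nn.Sequential(
            nn.Linear(dim, 4 * dim, bias=False),
            nn.Tanh(),
            nn.Linear(4 * dim, 1, bias=False),
        )

    def forward(self, embeddings, mask):
        # embeddings: (batch, seq_len, dim); mask: (batch, seq_len), bool.
        logits = self.score(embeddings).squeeze(-1)      # (batch, seq_len)
        logits = logits.masked_fill(~mask, float('-inf'))
        weights = torch.softmax(logits, dim=1)
        return torch.einsum('bs,bsd->bd', weights, embeddings)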
all_graph_embeddings = self._apply_graph_encoder(embeddings=all_init_embeddings, - graph=self._graph) - - common_graph_user_embs_all, common_graph_item_embs_all = torch.split( - all_graph_embeddings, [self._num_users + 2, self._num_items + 2] - ) - common_graph_user_embs_batch = common_graph_user_embs_all[user_ids] - common_graph_item_embs_batch, _ = create_masked_tensor( - data=common_graph_item_embs_all[all_sample_events], - lengths=all_sample_lengths - ) - - # formula 5: A_c = softmax(tanh(W_3 * h_u,uv) * (E_u,uv)^T) - graph_attention_matrix = torch.einsum('bd,bsd->bs', - self._general_interest_learning_encoder - (common_graph_user_embs_batch), - common_graph_item_embs_batch) - graph_attention_matrix[~mask] = -torch.inf - graph_attention_matrix = torch.softmax(graph_attention_matrix, dim=1) - - # formula 6: I_c = A_c * E_u,uv - original_graph_representation = torch.einsum('bs,bsd->bd', - graph_attention_matrix, - common_graph_item_embs_batch) - original_sequential_representation = sequential_representation - - # formula 13: I_comb = alpha * I_s + (1 - alpha) * I_c - # L_P (Downstream Loss) - combined_representation = (self._alpha * original_sequential_representation + - (1 - self._alpha) * original_graph_representation) - labels = inputs['{}.ids'.format(self._labels_prefix)] - labels_embeddings = self._item_embeddings(labels) - - # formula 7 - # L_IL (Interest-level CL) - sequential_representation_proj = self._sequential_projector(original_sequential_representation) - graph_representation_proj = self._graph_projector(original_graph_representation) - - # formula 9: H_u,uu = GraphEncoder(H_u, G_uu) - # L_UC (User-level CL) - user_graph_user_embs_all = self._apply_graph_encoder(embeddings=self._user_embeddings.weight, - graph=self._user_graph) - user_graph_user_embs_batch = user_graph_user_embs_all[user_ids] - - # formula 10 - # T_f,uu = MLP(H_u,uu) и T_f,uv = MLP(H_u,uv) - user_graph_user_embeddings_proj = self._user_projection(user_graph_user_embs_batch) - common_graph_user_embeddings_proj = self._user_projection(common_graph_user_embs_batch) - - # item level CL - common_graph_items_flat = common_graph_item_embs_batch[mask] - - item_graph_items_all = self._apply_graph_encoder(embeddings=self._item_embeddings.weight, - graph=self._item_graph) - item_graph_items_flat = item_graph_items_all[all_sample_events] - - unique_item_ids, inverse_indices = torch.unique(all_sample_events, - return_inverse=True) - - try: - from torch_scatter import scatter_mean - except ImportError: - # print("Warning: torch_scatter not found. 
Using a slower fallback function.") - def scatter_mean(src, index, dim=0, dim_size=None): - out_size = dim_size if dim_size is not None else index.max() + 1 - out = torch.zeros((out_size, src.size(1)), dtype=src.dtype, device=src.device) - counts = torch.bincount(index, minlength=out_size).unsqueeze(-1).clamp(min=1) - return out.scatter_add_(dim, index.unsqueeze(-1).expand_as(src), src) / counts - - num_unique_items = unique_item_ids.shape[0] - - unique_common_graph_items = scatter_mean(common_graph_items_flat, - inverse_indices, dim=0, - dim_size=num_unique_items) - - unique_item_graph_items = scatter_mean(item_graph_items_flat, - inverse_indices, dim=0, - dim_size=num_unique_items) - - # projection for Item-level Feature CL - unique_common_graph_items_proj = self._item_projection(unique_common_graph_items) - unique_item_graph_items_proj = self._item_projection(unique_item_graph_items) - - negative_ids = inputs['{}.ids'.format(self._negatives_prefix)] # (batch_size, num_negatives) - negative_embeddings = self._item_embeddings(negative_ids) # (batch_size, num_negatives, embedding_dim) - - # import code; code.interact(local=locals()) - - return { - # L_P (formula 14) - 'combined_representation': combined_representation, - 'label_representation': labels_embeddings, - - 'negative_representation': negative_embeddings, - - # for L_IL (formula 8) - - 'sequential_representation': sequential_representation_proj, - 'graph_representation': graph_representation_proj, - - # for L_UC (formula 11) - 'user_graph_user_embeddings': user_graph_user_embeddings_proj, - 'common_graph_user_embeddings': common_graph_user_embeddings_proj, - - # for L_IC - 'item_graph_item_embeddings': unique_item_graph_items_proj, - 'common_graph_item_embeddings': unique_common_graph_items_proj, - } - else: # eval mode - # formula 16: R(u,N) = Top-N((I_s)^T * h_o) - if '{}.ids'.format(self._candidate_prefix) in inputs: - candidate_events = inputs[ - '{}.ids'.format(self._candidate_prefix) - ] # (all_batch_candidates) - candidate_lengths = inputs[ - '{}.length'.format(self._candidate_prefix) - - ] # (batch_size) - - candidate_embeddings = self._item_embeddings( - candidate_events, - ) # (all_batch_candidates, embedding_dim) - - candidate_embeddings, _ = create_masked_tensor( - data=candidate_embeddings, - lengths=candidate_lengths, - ) # (batch_size, num_candidates, embedding_dim) - - candidate_scores = torch.einsum( - 'bd,bnd->bn', - sequential_representation, # I_s - candidate_embeddings, # h_o (and h_k) - ) # (batch_size, num_candidates) - else: - candidate_embeddings = ( - self._item_embeddings.weight - ) # (num_items, embedding_dim) - candidate_scores = torch.einsum( - 'bd,nd->bn', - sequential_representation, # I_s - candidate_embeddings, # all h_v - ) # (batch_size, num_items) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - - values, indices = torch.topk( - candidate_scores, - k=50, - dim=-1, - largest=True, - ) # (batch_size, 100), (batch_size, 100) - - return indices \ No newline at end of file diff --git a/src/irec/models/mrgsrec.py b/src/irec/models/mrgsrec.py deleted file mode 100644 index 9677c612..00000000 --- a/src/irec/models/mrgsrec.py +++ /dev/null @@ -1,144 +0,0 @@ -from .base import TorchModel - -from irec.utils import create_masked_tensor, get_activation_function - -import torch -import torch.nn as nn - - -class MRGSRecModel(TorchModel, config_name='mrgsrec'): - def __init__( - self, - sequence_prefix, - user_prefix, - positive_prefix, - negative_prefix, - 
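The item-level contrastive branch above averages the per-occurrence graph embeddings of every unique item with `torch_scatter.scatter_mean`, falling back to a pure-PyTorch version when the package is absent. How that reduction behaves, using `torch.unique(..., return_inverse=True)` exactly as the deleted code does (toy values):

import torch

events = torch.tensor([7, 3, 7, 5])        # item ids with repeats
emb = torch.arange(8.0).reshape(4, 2)      # one embedding per occurrence
unique_ids, inverse = torch.unique(events, return_inverse=True)

# Pure-torch scatter-mean over the inverse index, like the fallback above.
out = torch.zeros(len(unique_ids), emb.shape[1])
out.scatter_add_(0, inverse[:, None].expand_as(emb), emb)
counts = torch.bincount(inverse, minlength=len(unique_ids)).clamp(min=1)
out = out / counts[:, None].float()
# unique_ids == [3, 5, 7]; out[2] is the mean embedding of item 7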
candidate_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02, - ): - super().__init__() - self._sequence_prefix = sequence_prefix - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - self._candidate_prefix = candidate_prefix - - self._num_items = num_items - self._num_heads = num_heads - self._embedding_dim = embedding_dim - - self._item_embeddings = nn.Embedding( - num_embeddings=num_items - + 2, # add zero embedding + mask embedding - embedding_dim=embedding_dim, - ) - self._position_embeddings = nn.Embedding( - num_embeddings=max_sequence_length - + 1, # in order to include `max_sequence_length` value - embedding_dim=embedding_dim, - ) - - self._layernorm = nn.LayerNorm(embedding_dim, eps=layer_norm_eps) - self._dropout = nn.Dropout(dropout) - - transformer_encoder_layer = nn.TransformerEncoderLayer( - d_model=embedding_dim, - nhead=num_heads, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=get_activation_function(activation), - layer_norm_eps=layer_norm_eps, - batch_first=True, - ) - self._encoder = nn.TransformerEncoder( - transformer_encoder_layer, - num_layers, - ) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - negative_prefix=config['negative_prefix'], - candidate_prefix=config['candidate_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get( - 'num_heads', - int(config['embedding_dim'] // 64), - ), - num_layers=config['num_layers'], - dim_feedforward=config.get( - 'dim_feedforward', - 4 * config['embedding_dim'], - ), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02), - ) - - def forward(self, inputs): - all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) - ] # (all_batch_events) - all_sample_lengths = inputs[ - '{}.length'.format(self._sequence_prefix) - ] # (batch_size) - - embeddings = self._item_embeddings( - all_sample_events, - ) # (all_batch_events, embedding_dim) - - embeddings, mask = create_masked_tensor( - data=embeddings, - lengths=all_sample_lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - batch_size = mask.shape[0] - seq_len = mask.shape[1] - - positions = ( - torch.arange( - start=seq_len - 1, - end=-1, - step=-1, - device=mask.device, - )[None] - .tile([batch_size, 1]) - .long() - ) # (batch_size, seq_len) - positions_mask = ( - positions < all_sample_lengths[:, None] - ) # (batch_size, max_seq_len) - - positions = positions[positions_mask] # (all_batch_events) - position_embeddings = self._position_embeddings( - positions, - ) # (all_batch_events, embedding_dim) - position_embeddings, _ = create_masked_tensor( - data=position_embeddings, - lengths=all_sample_lengths, - ) # (batch_size, seq_len, embedding_dim) - assert torch.allclose(position_embeddings[~mask], embeddings[~mask]) - - embeddings = ( - embeddings + position_embeddings - ) # (batch_size, seq_len, embedding_dim) - - embeddings = self._layernorm( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = self._dropout( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - - embeddings[~mask] = 0 diff --git a/src/irec/models/ngcf.py b/src/irec/models/ngcf.py deleted file mode 
100644 index 2ee1b5fc..00000000 --- a/src/irec/models/ngcf.py +++ /dev/null @@ -1,244 +0,0 @@ -from .base import TorchModel - -from irec.utils import create_masked_tensor, DEVICE - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class NgcfModel(TorchModel, config_name='ngcf'): - def __init__( - self, - user_prefix, - positive_prefix, - graph, - num_users, - num_items, - embedding_dim, - num_layers, - dropout=0.0, - initializer_range=0.02, - ): - super().__init__() - self._user_prefix = user_prefix - self._positive_prefix = positive_prefix - self._graph = graph - self._num_users = num_users - self._num_items = num_items - self._embedding_dim = embedding_dim - self._num_layers = num_layers - self._dropout_rate = dropout - - self.dropout_list = nn.ModuleList() - self.GC_Linear_list = nn.ModuleList() - self.Bi_Linear_list = nn.ModuleList() - for i in range(self._num_layers): - self.dropout_list.append(nn.Dropout(dropout)) - self.GC_Linear_list.append(nn.Linear(embedding_dim, embedding_dim)) - self.Bi_Linear_list.append(nn.Linear(embedding_dim, embedding_dim)) - - self._user_embeddings = nn.Embedding( - num_embeddings=self._num_users + 2, - embedding_dim=self._embedding_dim, - ) - - self._item_embeddings = nn.Embedding( - num_embeddings=self._num_items + 2, - embedding_dim=self._embedding_dim, - ) - - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - user_prefix=config['user_prefix'], - positive_prefix=config['positive_prefix'], - graph=kwargs['graph'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - embedding_dim=config['embedding_dim'], - num_layers=config['num_layers'], - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02), - ) - - def _get_embeddings( - self, - inputs, - prefix, - ego_embeddings, - final_embeddings, - ): - ids = inputs['{}.ids'.format(prefix)] # (all_batch_events) - lengths = inputs['{}.length'.format(prefix)] # (batch_size) - - final_embeddings = final_embeddings[ - ids - ] # (all_batch_events, embedding_dim) - ego_embeddings = ego_embeddings( - ids, - ) # (all_batch_events, embedding_dim) - - padded_embeddings, mask = create_masked_tensor( - final_embeddings, - lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - padded_ego_embeddings, ego_mask = create_masked_tensor( - ego_embeddings, - lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - assert torch.all(mask == ego_mask) - - return padded_embeddings, padded_ego_embeddings, mask - - def _apply_graph_encoder(self): - ego_embeddings = torch.cat( - (self._user_embeddings.weight, self._item_embeddings.weight), - dim=0, - ) - all_embeddings = [ego_embeddings] - - if self._dropout_rate > 0: # drop some edges - if self.training: # training_mode - size = self._graph.size() - index = self._graph.indices().t() - values = self._graph.values() - random_index = torch.rand(len(values)) + ( - 1 - self._dropout_rate - ) - random_index = random_index.int().bool() - index = index[random_index] - values = values[random_index] / (1 - self._dropout_rate) - graph_dropped = torch.sparse.FloatTensor( - index.t(), - values, - size, - ) - else: # eval mode - graph_dropped = self._graph - else: - graph_dropped = self._graph - - for i in range(self._num_layers): - side_embeddings = torch.sparse.mm(graph_dropped, ego_embeddings) - sum_embeddings = F.leaky_relu( - self.GC_Linear_list[i](side_embeddings), - ) - bi_embeddings = 
torch.mul(ego_embeddings, side_embeddings) - bi_embeddings = F.leaky_relu(self.Bi_Linear_list[i](bi_embeddings)) - ego_embeddings = sum_embeddings + bi_embeddings - ego_embeddings = self.dropout_list[i](ego_embeddings) - - norm_embeddings = F.normalize(ego_embeddings, p=2, dim=1) - all_embeddings += [norm_embeddings] - - all_embeddings = torch.cat(all_embeddings, dim=-1) - user_final_embeddings, item_final_embeddings = torch.split( - all_embeddings, - [self._num_users + 2, self._num_items + 2], - ) - - return user_final_embeddings, item_final_embeddings - - def forward(self, inputs): - all_final_user_embeddings, all_final_item_embeddings = ( - self._apply_graph_encoder() - ) # (num_users + 2, embedding_dim), (num_items + 2, embedding_dim) - - user_embeddings, user_ego_embeddings, user_mask = self._get_embeddings( - inputs, - self._user_prefix, - self._user_embeddings, - all_final_user_embeddings, - ) - user_embeddings = user_embeddings[ - user_mask - ] # (all_batch_events, embedding_dim) - - if self.training: # training mode - positive_item_ids = inputs[ - '{}.ids'.format(self._positive_prefix) - ] # (all_batch_events) - positive_item_lengths = inputs[ - '{}.length'.format(self._positive_prefix) - ] # (batch_size) - - batch_size = positive_item_lengths.shape[0] - max_sequence_length = positive_item_lengths.max().item() - - mask = ( - torch.arange(end=max_sequence_length, device=DEVICE)[ - None - ].tile([batch_size, 1]) - < positive_item_lengths[:, None] - ) # (batch_size, max_seq_len) - - positive_user_ids = ( - torch.arange(batch_size, device=DEVICE)[None] - .tile([max_sequence_length, 1]) - .T - ) # (batch_size, max_seq_len) - positive_user_ids = positive_user_ids[mask] # (all_batch_items) - user_embeddings = user_embeddings[ - positive_user_ids - ] # (all_batch_items, embedding_dim) - - all_scores = torch.einsum( - 'ad,nd->an', - user_embeddings, - all_final_item_embeddings, - ) # (all_batch_items, num_items + 2) - - negative_mask = torch.zeros( - self._num_items + 2, - dtype=torch.bool, - device=DEVICE, - ) # (num_items + 2) - negative_mask[positive_item_ids] = 1 - - positive_scores = torch.gather( - input=all_scores, - dim=1, - index=positive_item_ids[..., None], - ) # (all_batch_items, 1) - - all_scores = torch.scatter_add( - input=all_scores, - dim=1, - index=positive_item_ids[..., None], - src=torch.ones_like(positive_item_ids[..., None]).float(), - ) # (all_batch_items, num_items + 2) - - return { - 'positive_scores': positive_scores, - 'negative_scores': all_scores, - 'item_embeddings': torch.cat( - ( - self._user_embeddings.weight, - self._item_embeddings.weight, - ), - dim=0, - ), - } - else: # eval mode - # b - batch_size, n - num_candidates, d - embedding_dim - candidate_scores = torch.einsum( - 'bd,nd->bn', - user_embeddings, - all_final_item_embeddings, - ) # (batch_size, num_items + 2) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - _, indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20) - - return indices diff --git a/src/irec/models/old_rqvae.py b/src/irec/models/old_rqvae.py new file mode 100644 index 00000000..86c1d583 --- /dev/null +++ b/src/irec/models/old_rqvae.py @@ -0,0 +1,136 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, input_dim, output_dim, dropout=0.1): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + + self.norm = nn.LayerNorm(input_dim) + 
        self.layer = nn.Linear(input_dim, output_dim)
+        self.act = nn.GELU()
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        embedding = x
+        embedding = self.norm(embedding)
+        embedding = self.layer(embedding)
+        embedding = self.act(embedding)
+        embedding = self.dropout(embedding)
+
+        if self.input_dim == self.output_dim:
+            return embedding + x
+        return embedding
+
+
+class Tower(nn.Module):
+    def __init__(self, dims, dropout):
+        super().__init__()
+        self.layers = nn.ModuleList()
+        for i in range(len(dims) - 1):
+            self.layers.append(ResidualBlock(dims[i], dims[i + 1], dropout))
+
+    def forward(self, x):
+        embedding = x
+        for layer in self.layers:
+            embedding = layer(embedding)
+        return embedding
+
+
+class RQVAE(nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        num_codebooks,
+        codebook_size,
+        embedding_dim,
+        layers,
+        dropout_prob=0.0,
+        beta=0.25,
+        quant_loss_weight=1.0,
+    ):
+        super().__init__()
+
+        self.input_dim = input_dim
+        self.num_codebooks = num_codebooks
+        self.codebook_size = codebook_size
+        self.embedding_dim = embedding_dim
+        self.beta = beta
+        self.quant_loss_weight = quant_loss_weight
+
+        self.layers = layers
+        self.dropout_prob = dropout_prob
+
+        self.encoder_layer_dims = [self.input_dim] + self.layers + [self.embedding_dim]
+        self.decoder_layer_dims = self.encoder_layer_dims[::-1]
+
+        # TODO add initialization with AE
+        self.encoder = Tower(
+            dims=self.encoder_layer_dims,
+            dropout=self.dropout_prob
+        )
+        self.decoder = Tower(
+            dims=self.decoder_layer_dims,
+            dropout=self.dropout_prob
+        )
+
+        self.codebooks = torch.nn.ParameterList()
+        for _ in range(num_codebooks):
+            # Register each codebook as an initialized, trainable parameter;
+            # a raw torch.FloatTensor here would be uninitialized memory and
+            # would not be picked up by the optimizer.
+            cb = nn.Parameter(torch.randn(codebook_size, embedding_dim))
+            self.codebooks.append(cb)
+
+    @staticmethod
+    def make_encoding_tower(d1, d2, bias=False):
+        return torch.nn.Sequential(
+            nn.LayerNorm(d1),
+            nn.Linear(d1, d1),
+            nn.GELU(),
+            torch.nn.Linear(d1, d2, bias=bias)
+        )
+
+    @staticmethod
+    def get_codebook_indices(remainder, codebook):
+        dist = torch.cdist(remainder, codebook)
+        return dist.argmin(dim=-1)
+
+    def forward(self, inputs):
+        latent_vector = self.encoder(inputs['embedding'])
+
+        latent_restored = 0
+        rqvae_loss = 0
+        clusters = []
+        remainder = latent_vector
+        for codebook in self.codebooks:
+            codebook_indices = self.get_codebook_indices(remainder, codebook)
+            clusters.append(codebook_indices)
+
+            quantized = codebook[codebook_indices]
+            codebook_vectors = remainder + (quantized - remainder).detach()
+
+            rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach())
+            rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach())
+
+            latent_restored += codebook_vectors
+            remainder = remainder - codebook_vectors
+
+        embeddings_restored = self.decoder(latent_restored)
+        recon_loss = F.mse_loss(embeddings_restored, inputs['embedding'])
+        loss = (recon_loss + self.quant_loss_weight * rqvae_loss).mean()
+
+        clusters_counts = []
+        for cluster in clusters:
+            clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size))
+
+        return loss, {
+            'loss': loss.item(),
+            'recon_loss': recon_loss.mean().item(),
+            'rqvae_loss': rqvae_loss.mean().item(),
+
+            'clusters_counts': clusters_counts,
+            'clusters': torch.stack(clusters).T,
+            'embedding_hat': embeddings_restored,
+        }
diff --git a/src/irec/models/pop.py b/src/irec/models/pop.py
deleted file mode 100644
index 82faa6e4..00000000
--- a/src/irec/models/pop.py
+++ /dev/null
@@ -1,46 +0,0 @@
-from .base import BaseModel
-
-import torch
-
-
-class PopModel(BaseModel, config_name='pop'):
-    def
__init__(self, label_prefix, counts_prefix, num_items): - self._label_prefix = label_prefix - self._counts_prefix = counts_prefix - self._num_items = num_items - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - label_prefix=config['label_prefix'], - counts_prefix=config['counts_prefix'], - num_items=kwargs['num_items'], - ) - - def __call__(self, inputs): - candidate_counts = inputs[ - '{}.ids'.format(self._counts_prefix) - ] # (all_batch_candidates) - candidate_counts_lengths = inputs[ - '{}.length'.format(self._counts_prefix) - ] # (batch_size) - batch_size = candidate_counts_lengths.shape[0] - - candidate_scores = torch.reshape( - candidate_counts, - shape=(batch_size, self._num_items + 2), - ).float() # (batch_size, num_items) - candidate_scores[:, 0] = -torch.inf # zero (padding) token - candidate_scores[ - :, - self._num_items + 1 :, - ] = -torch.inf # all not real items-related things - - _, indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20) - - return indices diff --git a/src/irec/models/pure_mf.py b/src/irec/models/pure_mf.py deleted file mode 100644 index 8c4c6a77..00000000 --- a/src/irec/models/pure_mf.py +++ /dev/null @@ -1,131 +0,0 @@ -from .base import TorchModel - -import torch -import torch.nn as nn - -from irec.utils import create_masked_tensor - - -class PureMF(TorchModel, config_name='pure_mf'): - def __init__( - self, - user_prefix, - positive_prefix, - negative_prefix, - num_users, - num_items, - embedding_dim, - initializer_range, - ): - super().__init__() - - self._user_prefix = user_prefix - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - - self._num_users = num_users - self._num_items = num_items - self._embedding_dim = embedding_dim - - self._user_embeddings = nn.Embedding( - num_embeddings=self._num_users + 2, - embedding_dim=self._embedding_dim, - ) - - self._item_embeddings = nn.Embedding( - num_embeddings=self._num_items + 2, - embedding_dim=self._embedding_dim, - ) - - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - user_prefix=config['user_prefix'], - positive_prefix=config['positive_prefix'], - negative_prefix=config['negative_prefix'], - num_users=kwargs['num_users'], - num_items=kwargs['num_items'], - embedding_dim=config['embedding_dim'], - initializer_range=config.get('initializer_range', 0.02), - ) - - def forward(self, inputs): - user_ids = inputs['{}.ids'.format(self._user_prefix)] # (batch_size) - user_embeddings = self._user_embeddings( - user_ids, - ) # (batch_size, embedding_dim) - - if self.training: # training mode - all_positive = inputs[ - '{}.ids'.format(self._positive_prefix) - ] # (all_batch_events) - all_positive_embeddings = self._item_embeddings( - all_positive, - ) # (all_batch_events, embedding_dim) - positive_lengths = inputs[ - '{}.length'.format(self._positive_prefix) - ] # (batch_size) - - all_negative = inputs[ - '{}.ids'.format(self._negative_prefix) - ] # (all_batch_events) - all_negative_embeddings = self._item_embeddings( - all_negative, - ) # (all_batch_events, embedding_dim) - negative_lengths = inputs[ - '{}.length'.format(self._negative_prefix) - ] # (batch_size) - - positive_embeddings, positive_mask = create_masked_tensor( - all_positive_embeddings, - positive_lengths, - ) - negative_embeddings, negative_mask = create_masked_tensor( - all_negative_embeddings, - negative_lengths, - ) - - positive_scores = torch.einsum( - 'bd,bsd->bs', - 
user_embeddings, - positive_embeddings, - ) # (batch_size, seq_len) - negative_scores = torch.einsum( - 'bd,bsd->bs', - user_embeddings, - negative_embeddings, - ) # (batch_size, seq_len) - - positive_scores = positive_scores[ - positive_mask - ] # (all_batch_events) - negative_scores = negative_scores[ - negative_mask - ] # (all_batch_events) - - return { - 'positive_scores': positive_scores, - 'negative_scores': negative_scores, - } - else: - candidate_embeddings = ( - self._item_embeddings.weight - ) # (num_items, embedding_dim) - candidate_scores = torch.einsum( - 'bd,nd->bn', - user_embeddings, - candidate_embeddings, - ) # (batch_size, num_items) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - _, indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20) - - return indices diff --git a/src/irec/models/pure_svd.py b/src/irec/models/pure_svd.py deleted file mode 100644 index 4a03731b..00000000 --- a/src/irec/models/pure_svd.py +++ /dev/null @@ -1,14 +0,0 @@ -from .base import BaseModel - - -class SVDModel(BaseModel, config_name='pure_svd'): - def __init__(self, rank): - super().__init__() - - self._rank = rank - self._method = 'PureSVD' - self._factors = {} - - @property - def rank(self): - return self._rank diff --git a/src/irec/models/random.py b/src/irec/models/random.py deleted file mode 100644 index 9a595c3d..00000000 --- a/src/irec/models/random.py +++ /dev/null @@ -1,41 +0,0 @@ -from .base import BaseModel - -import torch - - -class RandomModel(BaseModel, config_name='random'): - def __init__(self, label_prefix, num_items): - self._label_prefix = label_prefix - self._num_items = num_items - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - label_prefix=config['label_prefix'], - num_items=kwargs['num_items'], - ) - - def __call__(self, inputs): - labels_lengths = inputs[ - '{}.length'.format(self._label_prefix) - ] # (batch_size) - batch_size = labels_lengths.shape[0] - - candidate_scores = torch.rand( - batch_size, - self._num_items + 1, - ) # (batch_size, num_items) - candidate_scores[:, 0] = -torch.inf # zero (padding) token - candidate_scores[ - :, - self._num_items + 1 :, - ] = -torch.inf # all not real items-related things - - _, indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20) - - return indices diff --git a/src/irec/models/rqvae.py b/src/irec/models/rqvae.py new file mode 100644 index 00000000..5cdfe784 --- /dev/null +++ b/src/irec/models/rqvae.py @@ -0,0 +1,260 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, input_dim, output_dim, dropout=0.1): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + + self.norm = nn.LayerNorm(input_dim) + self.layer = nn.Linear(input_dim, output_dim) + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + embedding = x + embedding = self.norm(embedding) + embedding = self.layer(embedding) + embedding = self.act(embedding) + embedding = self.dropout(embedding) + + if self.input_dim == self.output_dim: + return embedding + x + return embedding + + +class Tower(nn.Module): + def __init__(self, dims, dropout): + super().__init__() + self.layers = nn.ModuleList() + for i in range(len(dims) - 1): + self.layers.append(ResidualBlock(dims[i], dims[i + 1], dropout)) + + def forward(self, x): + embedding = x + for layer in self.layers: + 
            embedding = layer(embedding)
+        return embedding
+
+
+class VectorQuantizer(nn.Module):
+
+    def __init__(
+        self,
+        codebook_size,
+        embedding_dim,
+        mu=0.25,
+    ):
+        super().__init__()
+        self.codebook_size = codebook_size
+        self.embedding_dim = embedding_dim
+        self.mu = mu
+
+        self.embedding = nn.Embedding(self.codebook_size, self.embedding_dim)
+
+    def get_codebook(self):
+        return self.embedding.weight
+
+    def forward(self, latent_embeddings):
+        # Get closest centroids via the expanded squared distance
+        # ||x||^2 + ||e||^2 - 2 * x.e
+        d = (
+            torch.sum(latent_embeddings**2, dim=1, keepdim=True)
+            + torch.sum(self.embedding.weight**2, dim=1, keepdim=True).t()
+            - 2 * torch.matmul(latent_embeddings, self.embedding.weight.t())
+        )
+        indices = torch.argmin(d, dim=-1)
+
+        x_q = self.embedding(indices)
+
+        # compute loss for embedding
+        commitment_loss = F.mse_loss(x_q.detach(), latent_embeddings)
+        codebook_loss = F.mse_loss(x_q, latent_embeddings.detach())
+
+        quantization_loss = codebook_loss + self.mu * commitment_loss
+
+        # preserve gradients (straight-through estimator)
+        x_q = latent_embeddings + (x_q - latent_embeddings).detach()
+
+        indices = indices.view(latent_embeddings.shape[:-1])
+
+        return x_q, quantization_loss, indices
+
+
+class ResidualVectorQuantizer(nn.Module):
+    def __init__(
+        self,
+        num_codebooks,
+        codebook_size,
+        embedding_dim,
+    ):
+        super().__init__()
+        self.num_codebooks = num_codebooks
+        self.codebook_size = codebook_size
+        self.embedding_dim = embedding_dim
+
+        self.vq_layers: list[VectorQuantizer] = nn.ModuleList([
+            VectorQuantizer(codebook_size, embedding_dim) for _ in range(num_codebooks)
+        ])
+
+    def forward(self, latent_embeddings):
+        all_losses = []
+        all_indices = []
+
+        x_q = 0
+        residual = latent_embeddings
+
+        for quantizer in self.vq_layers:
+            x_res, loss, indices = quantizer(residual)
+            residual = residual - x_res
+            x_q = x_q + x_res
+
+            all_losses.append(loss)
+            all_indices.append(indices)
+
+        mean_losses = torch.stack(all_losses).mean()
+        all_indices = torch.stack(all_indices, dim=-1)
+
+        return x_q, mean_losses, all_indices
+
+
+class RQVAE(nn.Module):
+    def __init__(
+        self,
+        input_dim,
+        num_codebooks,
+        codebook_size,
+        embedding_dim,
+        layers,
+        dropout_prob=0.0,
+        beta=0.25,
+        quant_loss_weight=1.0,
+        cf_loss_weight=1.0,
+        cf_embeddings=None
+    ):
+        super().__init__()
+
+        self.input_dim = input_dim
+        self.num_codebooks = num_codebooks
+        self.codebook_size = codebook_size
+        self.embedding_dim = embedding_dim
+        self.beta = beta
+        self.quant_loss_weight = quant_loss_weight
+
+        self.layers = layers
+        self.dropout_prob = dropout_prob
+        self.cf_embeddings = cf_embeddings
+        self.cf_loss_weight = cf_loss_weight
+
+        self.encoder_layer_dims = [self.input_dim] + self.layers + [self.embedding_dim]
+        self.decoder_layer_dims = self.encoder_layer_dims[::-1]
+
+        # TODO add initialization with AE
+        self.encoder = Tower(
+            dims=self.encoder_layer_dims,
+            dropout=self.dropout_prob
+        )
+        self.decoder = Tower(
+            dims=self.decoder_layer_dims,
+            dropout=self.dropout_prob
+        )
+
+        self.rq = ResidualVectorQuantizer(
+            num_codebooks=num_codebooks,
+            codebook_size=codebook_size,
+            embedding_dim=embedding_dim
+        )
+
+    @staticmethod
+    def get_codebook_indices(remainder, quantizer):
+        dist = (
+            torch.sum(remainder**2, dim=1, keepdim=True)
+            + torch.sum(quantizer.embedding.weight**2, dim=1, keepdim=True).t()
+            - 2 * torch.matmul(remainder, quantizer.embedding.weight.t())
+        )
+        return dist.argmin(dim=-1)
+
+    def forward(self, inputs):
+        latent_vector = self.encoder(inputs['embedding'])
+
+        latent_restored = 0
+        rqvae_loss = 0
+        clusters = []
+        remainder = latent_vector
+        for quantizer in
self.rq.vq_layers: + codebook_indices = self.get_codebook_indices(remainder, quantizer) + clusters.append(codebook_indices) + + quantized = quantizer.embedding(codebook_indices) + codebook_vectors = remainder + (quantized - remainder).detach() + + rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach()) + rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach()) + + # codebook_vectors, quantizer_loss, codebook_indices = quantizer(remainder) + # rqvae_loss += quantizer_loss + + latent_restored += codebook_vectors + remainder = remainder - codebook_vectors + + embeddings_restored = self.decoder(latent_restored) + recon_loss = F.mse_loss(embeddings_restored, inputs['embedding']) + + # TODO for now + # if self.cf_embeddings is not None: + # cf_embedding_in_batch = self.cf_embeddings[item_ids] + # cf_embedding_in_batch = torch.from_numpy(cf_embedding_in_batch).to(quantized_embeddings.device) + # cf_loss = self.CF_loss(quantized_embeddings, cf_embedding_in_batch) + # else: + cf_loss = torch.as_tensor(0.0) + + loss = (recon_loss + self.quant_loss_weight * rqvae_loss + self.cf_loss_weight * cf_loss).mean() + + clusters_counts = [] + for cluster in clusters: + clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size)) + + # loss, recon_loss, cf_loss, rq_loss = self.compute_loss( + # content_embeddings=content_embeddings, + # out_embeddings=out_embeddings, + # item_ids=item_ids, + # rq_loss=rq_loss, + # quantized_embeddings=quantized_embeddings + # ) + + return loss, { + 'loss': loss.item(), + 'recon_loss': recon_loss.mean().item(), + 'rqvae_loss': rqvae_loss.mean().item(), + 'cf_loss': cf_loss.item(), + + 'clusters_counts': clusters_counts, + 'clusters': torch.stack(clusters).T, + 'embedding_hat': embeddings_restored, + } + + # def CF_loss(self, quantized_rep, encoded_rep): + # batch_size = quantized_rep.size(0) + # labels = torch.arange(batch_size, dtype=torch.long, device=quantized_rep.device) + # similarities = quantized_rep @ encoded_rep.T + # cf_loss = F.cross_entropy(similarities, labels) + # return cf_loss + + # @torch.no_grad() + # def get_indices(self, content_embeddings): + # latent_embeddings = self.encoder(content_embeddings) + # _, _, indices = self.rq(latent_embeddings) + # return indices + + # def compute_loss(self, content_embeddings, out_embeddings, item_ids, rq_loss, quantized_embeddings): + # if self.loss_type == 'mse': + # recon_loss = F.mse_loss(content_embeddings, out_embeddings, reduction='mean') + # elif self.loss_type == 'l1': + # recon_loss = F.l1_loss(content_embeddings, out_embeddings, reduction='mean') + # else: + # raise ValueError('incompatible loss type') + + # if self.cf_embeddings is not None: + # cf_embedding_in_batch = self.cf_embeddings[item_ids] + # cf_embedding_in_batch = torch.from_numpy(cf_embedding_in_batch).to(quantized_embeddings.device) + # cf_loss = self.CF_loss(quantized_embeddings, cf_embedding_in_batch) + # else: + # cf_loss = torch.as_tensor(0.0) + + # total_loss = recon_loss + self.quant_loss_weight * rq_loss + self.cf_loss_weight * cf_loss + + # return total_loss, recon_loss, cf_loss, rq_loss \ No newline at end of file diff --git a/src/irec/models/s3rec.py b/src/irec/models/s3rec.py deleted file mode 100644 index fd5fdbd1..00000000 --- a/src/irec/models/s3rec.py +++ /dev/null @@ -1,267 +0,0 @@ -from .base import SequentialTorchModel - -import torch -import torch.nn as nn - -from irec.utils import create_masked_tensor - - -class S3RecModel(SequentialTorchModel, config_name='s3rec'): - 
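The new `rqvae.py` replaces the ad-hoc codebook loop of `old_rqvae.py` with explicit `VectorQuantizer` / `ResidualVectorQuantizer` modules. One residual round reduces to three steps: nearest-centroid lookup, the weighted commitment-plus-codebook MSE pair, and a straight-through estimator so gradients flow past the argmin. A compact sketch with toy shapes (values illustrative; the 0.25 weight matches the `mu` default above):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
codebook = torch.nn.Embedding(16, 8)        # codebook_size=16, embedding_dim=8
x = torch.randn(4, 8, requires_grad=True)   # latent batch

# Nearest centroid via the expanded squared-distance form used above
d = (x**2).sum(1, keepdim=True) + (codebook.weight**2).sum(1) - 2 * x @ codebook.weight.t()
indices = d.argmin(dim=-1)
x_q = codebook(indices)

commitment_loss = F.mse_loss(x_q.detach(), x)  # pulls encoder output toward codes
codebook_loss = F.mse_loss(x_q, x.detach())    # pulls codes toward encoder output
loss = codebook_loss + 0.25 * commitment_loss  # mu = 0.25 default

# Straight-through estimator: forward uses x_q, backward sees identity w.r.t. x
x_q_st = x + (x_q - x).detach()
residual = x - x_q_st  # fed to the next codebook in ResidualVectorQuantizer
```

The expanded `||x||^2 + ||e||^2 - 2 x.e` form gives the same argmin as `torch.cdist(x, codebook.weight)`, which is the variant `old_rqvae.py` used.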
def __init__( - self, - sequence_prefix, - positive_prefix, - negative_prefix, - sequence_segment_prefix, - positive_segment_prefix, - negative_segment_prefix, - candidate_prefix, - num_items, - max_sequence_length, - is_training, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-5, - initializer_range=0.02, - ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - is_causal=is_training, - ) - self._sequence_prefix = sequence_prefix - self._positive_prefix = positive_prefix - self._negative_prefix = negative_prefix - self._sequence_segment_prefix = sequence_segment_prefix - self._positive_segment_prefix = positive_segment_prefix - self._negative_segment_prefix = negative_segment_prefix - self._candidate_prefix = candidate_prefix - self._is_training = is_training - self._mask_item_idx = num_items + 1 - - self.aap_norm = nn.Linear(embedding_dim, embedding_dim) - self.mip_norm = nn.Linear(embedding_dim, embedding_dim) - self.map_norm = nn.Linear(embedding_dim, embedding_dim) - self.sp_norm = nn.Linear(embedding_dim, embedding_dim) - - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - negative_prefix=config['negative_prefix'], - sequence_segment_prefix=config['sequence_segment_prefix'], - positive_segment_prefix=config['positive_segment_prefix'], - negative_segment_prefix=config['negative_segment_prefix'], - candidate_prefix=config['candidate_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - is_training=config['is_training'], - embedding_dim=config['embedding_dim'], - num_heads=config.get( - 'num_heads', - int(config['embedding_dim'] // 64), - ), - num_layers=config['num_layers'], - dim_feedforward=config.get( - 'dim_feedforward', - 4 * config['embedding_dim'], - ), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02), - ) - - def masked_item_prediction( - self, - sequence_embeddings, - sequence_mask, - target_item, - ): - all_items = sequence_embeddings[ - sequence_mask - ] # (all_batch_items, emb_dim) - score = torch.einsum( - 'ad,ad->a', - all_items, - target_item, - ) # (all_batch_items) - return torch.sigmoid(score) # (all_batch_items) - - def segment_prediction(self, context, segment): - score = torch.einsum( - 'bd,bd->b', - self.sp_norm(context), - segment, - ) # (batch_size) - return torch.sigmoid(score) # (batch_size) - - def forward(self, inputs): - all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) - ] # (all_batch_events) - all_sample_lengths = inputs[ - '{}.length'.format(self._sequence_prefix) - ] # (batch_size) - - embeddings, mask = self._apply_sequential_encoder( - all_sample_events, - all_sample_lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - last_embeddings = self._get_last_embedding( - embeddings, - mask, - ) # (batch_size, embedding_dim) - - if self._is_training: - if self.training: # training mode - all_positive_sample_events = inputs[ - '{}.ids'.format(self._positive_prefix) - ] # (all_batch_events) - all_negative_sample_events = inputs[ - '{}.ids'.format(self._negative_prefix) - ] # 
(all_batch_events) - - all_sample_embeddings = embeddings[ - mask - ] # (all_batch_events, embedding_dim) - all_positive_sample_embeddings = self._item_embeddings( - all_positive_sample_events, - ) # (all_batch_events, embedding_dim) - all_negative_sample_embeddings = self._item_embeddings( - all_negative_sample_events, - ) # (all_batch_events, embedding_dim) - - return { - 'current_embeddings': all_sample_embeddings, - 'positive_embeddings': all_positive_sample_embeddings, - 'negative_embeddings': all_negative_sample_embeddings, - } - else: # eval mode - if '{}.ids'.format(self._candidate_prefix) in inputs: - candidate_events = inputs[ - '{}.ids'.format(self._candidate_prefix) - ] # (all_batch_candidates) - candidate_lengths = inputs[ - '{}.length'.format(self._candidate_prefix) - ] # (batch_size) - - candidate_embeddings = self._item_embeddings( - candidate_events, - ) # (all_batch_candidates, embedding_dim) - - candidate_embeddings, _ = create_masked_tensor( - data=candidate_embeddings, - lengths=candidate_lengths, - ) # (batch_size, num_candidates, embedding_dim) - - candidate_scores = torch.einsum( - 'bd,bnd->bn', - last_embeddings, - candidate_embeddings, - ) # (batch_size, num_candidates) - else: - candidate_embeddings = ( - self._item_embeddings.weight - ) # (num_items, embedding_dim) - candidate_scores = torch.einsum( - 'bd,nd->bn', - last_embeddings, - candidate_embeddings, - ) # (batch_size, num_items) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - return candidate_scores - else: - # Masked Item Prediction - mip_mask = ( - all_sample_events == self._mask_item_idx - ).bool() # (all_batch_events) - embeddings = embeddings[mask][ - mip_mask - ] # (all_batch_events, embedding_dim) - positive_item_events = inputs[ - '{}.ids'.format(self._positive_prefix) - ][mip_mask] # (all_batch_events) - negative_item_events = inputs[ - '{}.ids'.format(self._negative_prefix) - ][mip_mask] # (all_batch_events) - - positive_item_embeddings = self._item_embeddings( - positive_item_events, - ) # (all_batch_events, embedding_dim) - negative_item_embeddings = self._item_embeddings( - negative_item_events, - ) # (all_batch_events, embedding_dim) - - # Sequence Prediction - all_segment_events = inputs[ - '{}.ids'.format(self._sequence_segment_prefix) - ] # (all_batch_events) - all_segment_lengths = inputs[ - '{}.length'.format(self._sequence_segment_prefix) - ] # (batch_size) - segment_embeddings, segment_mask = self._apply_sequential_encoder( - all_segment_events, - all_segment_lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - last_segment_embeddings = self._get_last_embedding( - segment_embeddings, - segment_mask, - ) # (batch_size, embedding_dim) - - positive_segment_events = inputs[ - '{}.ids'.format(self._positive_segment_prefix) - ] # (all_batch_events) - positive_segment_lengths = inputs[ - '{}.length'.format(self._positive_segment_prefix) - ] # (batch_size) - positive_segment_embeddings, positive_segment_mask = ( - self._apply_sequential_encoder( - positive_segment_events, - positive_segment_lengths, - ) - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - last_positive_segment_embeddings = self._get_last_embedding( - positive_segment_embeddings, - positive_segment_mask, - ) # (batch_size, embedding_dim) - - negative_segment_events = inputs[ - '{}.ids'.format(self._negative_segment_prefix) - ] # (all_batch_events) - negative_segment_lengths = inputs[ - '{}.length'.format(self._negative_segment_prefix) - 
] # (batch_size) - negative_segment_embeddings, negative_segment_mask = ( - self._apply_sequential_encoder( - negative_segment_events, - negative_segment_lengths, - ) - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - last_negative_segment_embeddings = self._get_last_embedding( - negative_segment_embeddings, - negative_segment_mask, - ) # (batch_size, embedding_dim) - - return { - 'positive_representation': positive_item_embeddings, - 'negative_representation': negative_item_embeddings, - 'current_representation': embeddings, - 'positive_segment_representation': last_positive_segment_embeddings, - 'negative_segment_representation': last_negative_segment_embeddings, - 'current_segment_representation': last_segment_embeddings, - } diff --git a/src/irec/models/sasrec.py b/src/irec/models/sasrec.py deleted file mode 100644 index e97019a9..00000000 --- a/src/irec/models/sasrec.py +++ /dev/null @@ -1,221 +0,0 @@ -from irec.models import SequentialTorchModel -from irec.utils import create_masked_tensor - -import torch - - -class SasRecModel(SequentialTorchModel, config_name='sasrec'): - - def __init__( - self, - sequence_prefix, - positive_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02 - ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - is_causal=True - ) - self._sequence_prefix = sequence_prefix - self._positive_prefix = positive_prefix - - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get('num_heads', int(config['embedding_dim'] // 64)), - num_layers=config['num_layers'], - dim_feedforward=config.get('dim_feedforward', 4 * config['embedding_dim']), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02) - ) - - def forward(self, inputs): - all_sample_events = inputs['{}.ids'.format(self._sequence_prefix)] # (all_batch_events) - all_sample_lengths = inputs['{}.length'.format(self._sequence_prefix)] # (batch_size) - - embeddings, mask = self._apply_sequential_encoder( - all_sample_events, all_sample_lengths - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - if self.training: # training mode - all_positive_sample_events = inputs['{}.ids'.format(self._positive_prefix)] # (all_batch_events) - - all_sample_embeddings = embeddings[mask] # (all_batch_events, embedding_dim) - - all_embeddings = self._item_embeddings.weight # (num_items, embedding_dim) - - # a -- all_batch_events, n -- num_items, d -- embedding_dim - all_scores = torch.einsum( - 'ad,nd->an', - all_sample_embeddings, - all_embeddings - ) # (all_batch_events, num_items) - - positive_scores = torch.gather( - input=all_scores, - dim=1, - index=all_positive_sample_events[..., None] - )[:, 0] # (all_batch_items) - - negative_scores = torch.gather( - input=all_scores, - dim=1, - index=torch.randint(low=0, high=all_scores.shape[1], size=all_positive_sample_events.shape, 
device=all_positive_sample_events.device)[..., None] - )[:, 0] # (all_batch_items) - - # sample_ids, _ = create_masked_tensor( - # data=all_sample_events, - # lengths=all_sample_lengths - # ) # (batch_size, seq_len) - - # sample_ids = torch.repeat_interleave(sample_ids, all_sample_lengths, dim=0) # (all_batch_events, seq_len) - - # negative_scores = torch.scatter( - # input=all_scores, - # dim=1, - # index=sample_ids, - # src=torch.ones_like(sample_ids) * (-torch.inf) - # ) # (all_batch_events, num_items) - - return { - 'positive_scores': positive_scores, - 'negative_scores': negative_scores - } - else: # eval mode - last_embeddings = self._get_last_embedding(embeddings, mask) # (batch_size, embedding_dim) - # b - batch_size, n - num_candidates, d - embedding_dim - candidate_scores = torch.einsum( - 'bd,nd->bn', - last_embeddings, - self._item_embeddings.weight - ) # (batch_size, num_items + 2) - - _, indices = torch.topk( - candidate_scores, - k=50, dim=-1, largest=True - ) # (batch_size, 20) - - return indices - - -class SasRecInBatchModel(SasRecModel, config_name='sasrec_in_batch'): - - def __init__( - self, - sequence_prefix, - positive_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02 - ): - super().__init__( - sequence_prefix=sequence_prefix, - positive_prefix=positive_prefix, - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - initializer_range=initializer_range - ) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get('num_heads', int(config['embedding_dim'] // 64)), - num_layers=config['num_layers'], - dim_feedforward=config.get('dim_feedforward', 4 * config['embedding_dim']), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02) - ) - - def forward(self, inputs): - all_sample_events = inputs['{}.ids'.format(self._sequence_prefix)] # (all_batch_events) - all_sample_lengths = inputs['{}.length'.format(self._sequence_prefix)] # (batch_size) - - embeddings, mask = self._apply_sequential_encoder( - all_sample_events, all_sample_lengths - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - if self.training: # training mode - # queries - in_batch_queries_embeddings = embeddings[mask] # (all_batch_events, embedding_dim) - - # positives - in_batch_positive_events = inputs['{}.ids'.format(self._positive_prefix)] # (all_batch_events) - in_batch_positive_embeddings = self._item_embeddings( - in_batch_positive_events - ) # (all_batch_events, embedding_dim) - - # negatives - batch_size = all_sample_lengths.shape[0] - random_ids = torch.randperm(in_batch_positive_events.shape[0]) - in_batch_negative_ids = in_batch_positive_events[random_ids][:batch_size] - - in_batch_negative_embeddings = self._item_embeddings( - in_batch_negative_ids - ) # (batch_size, embedding_dim) - - return { - 'query_embeddings': in_batch_queries_embeddings, - 'positive_embeddings': in_batch_positive_embeddings, - 'negative_embeddings': 
in_batch_negative_embeddings - } - else: # eval mode - last_embeddings = self._get_last_embedding(embeddings, mask) # (batch_size, embedding_dim) - - # b - batch_size, n - num_candidates, d - embedding_dim - candidate_scores = torch.einsum( - 'bd,nd->bn', - last_embeddings, - self._item_embeddings.weight - ) # (batch_size, num_items + 2) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1:] = -torch.inf - - _, indices = torch.topk( - candidate_scores, - k=50, dim=-1, largest=True - ) # (batch_size, 20) - - return indices \ No newline at end of file diff --git a/src/irec/models/sasrec_ce.py b/src/irec/models/sasrec_ce.py deleted file mode 100644 index c07b9ff6..00000000 --- a/src/irec/models/sasrec_ce.py +++ /dev/null @@ -1,108 +0,0 @@ -from .base import SequentialTorchModel - -import torch -import torch.nn as nn - - -class SasRecCeModel(SequentialTorchModel, config_name='sasrec_ce'): - def __init__( - self, - sequence_prefix, - positive_prefix, - num_items, - max_sequence_length, - embedding_dim, - num_heads, - num_layers, - dim_feedforward, - dropout=0.0, - activation='relu', - layer_norm_eps=1e-9, - initializer_range=0.02, - ): - super().__init__( - num_items=num_items, - max_sequence_length=max_sequence_length, - embedding_dim=embedding_dim, - num_heads=num_heads, - num_layers=num_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - is_causal=True, - ) - self._sequence_prefix = sequence_prefix - self._positive_prefix = positive_prefix - - self._output_projection = nn.Linear( - in_features=embedding_dim, - out_features=embedding_dim, - ) - - self._init_weights(initializer_range) - - @classmethod - def create_from_config(cls, config, **kwargs): - return cls( - sequence_prefix=config['sequence_prefix'], - positive_prefix=config['positive_prefix'], - num_items=kwargs['num_items'], - max_sequence_length=kwargs['max_sequence_length'], - embedding_dim=config['embedding_dim'], - num_heads=config.get( - 'num_heads', - int(config['embedding_dim'] // 64), - ), - num_layers=config['num_layers'], - dim_feedforward=config.get( - 'dim_feedforward', - 4 * config['embedding_dim'], - ), - dropout=config.get('dropout', 0.0), - initializer_range=config.get('initializer_range', 0.02), - ) - - def forward(self, inputs): - all_sample_events = inputs[ - '{}.ids'.format(self._sequence_prefix) - ] # (all_batch_events) - all_sample_lengths = inputs[ - '{}.length'.format(self._sequence_prefix) - ] # (batch_size) - - embeddings, mask = self._apply_sequential_encoder( - all_sample_events, - all_sample_lengths, - ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) - - embeddings = self._output_projection( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = torch.nn.functional.gelu( - embeddings, - ) # (batch_size, seq_len, embedding_dim) - embeddings = torch.einsum( - 'bsd,nd->bsn', - embeddings, - self._item_embeddings.weight, - ) # (batch_size, seq_len, num_items + 2) - - if self.training: # training mode - return {'logits': embeddings[mask]} - else: # eval mode - candidate_scores = self._get_last_embedding( - embeddings, - mask, - ) # (batch_size, num_items + 2) - candidate_scores[:, 0] = -torch.inf - candidate_scores[:, self._num_items + 1 :] = -torch.inf - - _, indices = torch.topk( - candidate_scores, - k=20, - dim=-1, - largest=True, - ) # (batch_size, 20) - - return indices diff --git a/src/irec/optimizer/__init__.py b/src/irec/optimizer/__init__.py deleted file mode 100644 index 
0ba2efec..00000000 --- a/src/irec/optimizer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .base import BaseOptimizer - -__all__ = ['BaseOptimizer'] diff --git a/src/irec/optimizer/base.py b/src/irec/optimizer/base.py deleted file mode 100644 index 0fc1f70a..00000000 --- a/src/irec/optimizer/base.py +++ /dev/null @@ -1,78 +0,0 @@ -import copy - -from irec.utils import MetaParent - -import torch - -OPTIMIZERS = { - 'sgd': torch.optim.SGD, - 'adam': torch.optim.Adam, - 'adamw': torch.optim.AdamW, -} - -SCHEDULERS = { - 'step': torch.optim.lr_scheduler.StepLR, - 'cyclic': torch.optim.lr_scheduler.CyclicLR, -} - - -class BaseOptimizer(metaclass=MetaParent): - pass - - -class BasicOptimizer(BaseOptimizer, config_name='basic'): - def __init__( - self, - model, - optimizer, - scheduler=None, - clip_grad_threshold=None, - ): - self._model = model - self._optimizer = optimizer - self._scheduler = scheduler - self._clip_grad_threshold = clip_grad_threshold - - @classmethod - def create_from_config(cls, config, **kwargs): - optimizer_cfg = copy.deepcopy(config['optimizer']) - optimizer = OPTIMIZERS[optimizer_cfg.pop('type')]( - kwargs['model'].parameters(), - **optimizer_cfg, - ) - - if 'scheduler' in config: - scheduler_cfg = copy.deepcopy(config['scheduler']) - scheduler = SCHEDULERS[scheduler_cfg.pop('type')]( - optimizer, - **scheduler_cfg, - ) - else: - scheduler = None - - return cls( - model=kwargs['model'], - optimizer=optimizer, - scheduler=scheduler, - clip_grad_threshold=config.get('clip_grad_threshold', None), - ) - - def step(self, loss): - self._optimizer.zero_grad() - loss.backward() - - if self._clip_grad_threshold is not None: - torch.nn.utils.clip_grad_norm_( - self._model.parameters(), - self._clip_grad_threshold, - ) - - self._optimizer.step() - if self._scheduler is not None: - self._scheduler.step() - - def state_dict(self): - state_dict = {'optimizer': self._optimizer.state_dict()} - if self._scheduler is not None: - state_dict.update({'scheduler': self._scheduler.state_dict()}) - return state_dict diff --git a/src/irec/pretrain.py b/src/irec/pretrain.py deleted file mode 100644 index e557c2fc..00000000 --- a/src/irec/pretrain.py +++ /dev/null @@ -1,118 +0,0 @@ -import irec.utils -from irec.utils import ( - parse_args, - create_logger, - fix_random_seed, - DEVICE, - ensure_checkpoints_dir, -) - -from irec.dataset import BaseDataset -from irec.dataloader import BaseDataloader -from irec.models import BaseModel -from irec.optimizer import BaseOptimizer -from irec.loss import BaseLoss -from irec.callbacks import BaseCallback - -import copy -import json -import torch - -logger = create_logger(name=__name__) -seed_val = 42 - - -def pretrain(dataloader, model, optimizer, loss_function, callback, epoch_cnt): - step_num = 0 - best_checkpoint = None - - logger.debug('Start pretraining...') - - for epoch in range(epoch_cnt): - logger.debug(f'Start epoch {epoch}') - for step, batch in enumerate(dataloader): - model.train() - - for key, values in batch.items(): - batch[key] = batch[key].to(DEVICE) - - batch.update(model.pretrain(batch)) - loss = loss_function(batch) - - optimizer.step(loss) - callback(batch, step_num) - step_num += 1 - - best_checkpoint = copy.deepcopy(model.state_dict()) - - logger.debug('Pretraining procedure has been finished!') - return best_checkpoint - - -def main(): - fix_random_seed(seed_val) - config = parse_args() - - irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER = ( - irec.utils.tensorboards.TensorboardWriter(config['experiment_name']) - ) - - 
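The deleted `BasicOptimizer.step` above folded `zero_grad`, `backward`, gradient clipping, the optimizer step, and the scheduler step into one call. Under the new design that sequence is split across `TrainingRunner._run_backward` and `_run_optimizer` (see the `runners/train.py` hunk below), and this patch carries no clipping or scheduler call; presumably those would hang off the `before_optimizer` event. For reference, the plain-PyTorch equivalent of the old `step` (toy model, learning rate, and threshold are illustrative):

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)
clip_grad_threshold = 1.0

loss = model(torch.randn(2, 4)).pow(2).mean()

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_threshold)
optimizer.step()
scheduler.step()
```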
logger.debug('Training config: \n{}'.format(json.dumps(config, indent=2))) - - dataset = BaseDataset.create_from_config(config['dataset']) - - train_sampler, test_sampler = dataset.get_samplers() - - train_dataloader = BaseDataloader.create_from_config( - config['dataloader']['train'], - dataset=train_sampler, - **dataset.meta, - ) - - validation_dataloader = BaseDataloader.create_from_config( - config['dataloader']['validation'], - dataset=test_sampler, - **dataset.meta, - ) - - model = BaseModel.create_from_config(config['model'], **dataset.meta).to( - DEVICE, - ) - - loss_function = BaseLoss.create_from_config(config['loss']) - - optimizer = BaseOptimizer.create_from_config( - config['optimizer'], - model=model, - ) - - callback = BaseCallback.create_from_config( - config['callback'], - model=model, - dataloader=validation_dataloader, - optimizer=optimizer, - ) - - logger.debug('Everything is ready for pretraining process!') - - # Pretrain process - pretrain( - dataloader=train_dataloader, - model=model, - optimizer=optimizer, - loss_function=loss_function, - callback=callback, - epoch_cnt=config['pretrain_epochs_num'], - ) - - logger.debug('Saving model...') - ensure_checkpoints_dir() - checkpoint_path = './checkpoints/pretrain_{}_final_state.pth'.format( - config['experiment_name'], - ) - torch.save(model.state_dict(), checkpoint_path) - logger.debug('Saved model as {}'.format(checkpoint_path)) - - -if __name__ == '__main__': - main() diff --git a/src/irec/runners/__init__.py b/src/irec/runners/__init__.py new file mode 100644 index 00000000..a8e8257d --- /dev/null +++ b/src/irec/runners/__init__.py @@ -0,0 +1,18 @@ +from irec.runners.base import Runner, RunnerContext, BatchRunner, BatchRunnerContext +from irec.runners.train import TrainingRunner, TrainingRunnerContext +from irec.runners.inference import InferenceRunner, InferenceRunnerContext + + +__all__ = [ + 'Runner', + 'RunnerContext', + + 'BatchRunner', + 'BatchRunnerContext', + + 'TrainingRunner', + 'TrainingRunnerContext', + + 'InferenceRunner', + 'InferenceRunnerContext', +] \ No newline at end of file diff --git a/src/irec/runners/base.py b/src/irec/runners/base.py new file mode 100644 index 00000000..40828ba4 --- /dev/null +++ b/src/irec/runners/base.py @@ -0,0 +1,155 @@ +import dataclasses +from typing import Any, Dict, Union + +import torch + +import irec.callbacks as cb + + +@dataclasses.dataclass +class RunnerContext: + metrics: Dict[str, Union[int, float]] = dataclasses.field(default_factory=dict) + + +@dataclasses.dataclass +class BatchRunnerContext(RunnerContext): + batch: Any = None + + +class Runner: + def __init__(self, callbacks: cb.Callback): + super().__init__() + self._callback = cb.Composite(*callbacks, declared_events=self.declared_events) + self._global_step = 0 + self._global_finished = False + + @property + def callback(self): + return self._callback + + @property + def global_step(self): + return self._global_step + + @property + def global_finished(self): + return self._global_finished + + def state_dict(self): + return { + 'callback': self._callback.state_dict(), + 'global_step': self._global_step, + 'global_finished': self._global_finished + } + + def load_state_dict(self, state_dict): + self._callback.load_state_dict(state_dict['callback']) + self._global_step = state_dict['global_step'] + self._global_finished = state_dict['global_finished'] + + def run(self): + self._callback.before_run(self) + self._callback.load_snapshot(self) + while not self._global_finished: + try: + context = self._run_step() 
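A detail of `Runner.run` worth calling out: the loop carries no epoch or step budget of its own. It ends only when `_run_step` raises `StopIteration`, which `BatchRunner` below surfaces from `next(self._dataset_iterator)`, so run length is a property of the data source rather than the runner. A hedged sketch of a step-limited iterable one could pass as `dataset` (the wrapper is illustrative, not part of the patch):

```python
class StepLimited:
    """Iterable yielding at most `max_steps` batches from `base` (illustrative helper)."""

    def __init__(self, base, max_steps):
        self._base = base
        self._max_steps = max_steps

    def __iter__(self):
        for step, batch in enumerate(self._base):
            if step >= self._max_steps:
                return  # exhausts the iterator, raising StopIteration in the runner
            yield batch

batches = StepLimited([{'x': 1}, {'x': 2}, {'x': 3}], max_steps=2)
assert [b['x'] for b in batches] == [1, 2]
```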
+                self._callback.after_step(self, context)
+                self._callback.save_snapshot(self)
+            except StopIteration:
+                self._global_finished = True
+                context = self._create_context()
+                self._callback.save_snapshot(self)
+        self._callback.after_run(self, context)
+        return context
+
+    # Only these two functions below should be re-implemented in other runners
+    @property
+    def declared_events(self):
+        return cb.Callback.declared_events
+
+    def _create_context(self):
+        return RunnerContext()
+
+    def _run_step(self):
+        context = self._create_context()
+        return context
+
+
+class BatchRunner(Runner):
+    def __init__(self, dataset, callbacks):
+        super().__init__(callbacks)
+        self._dataset = dataset
+        self._dataset_iterator = None
+
+    @property
+    def dataset(self):
+        return self._dataset
+
+    @property
+    def dataset_iterator(self):
+        return self._dataset_iterator
+
+    @property
+    def dataset_has_state(self):
+        # TODO: Maybe use runtime checkable Stateful protocol?
+        # https://pytorch.org/docs/stable/_modules/torch/distributed/checkpoint/stateful.html#Stateful
+        return (callable(getattr(self._dataset_iterator, 'state_dict', None)) and
+                callable(getattr(self._dataset_iterator, 'load_state_dict', None)))
+
+    def state_dict(self):
+        state_dict = super().state_dict()
+        state_dict.setdefault('distributed_state_dict', {})
+        assert 'dataset' not in state_dict['distributed_state_dict']
+        state_dict['distributed_state_dict']['dataset'] = (
+            self._dataset_iterator.state_dict() if self.dataset_has_state else None
+        )
+        return state_dict
+
+    def load_state_dict(self, state_dict):
+        super().load_state_dict(state_dict)
+        if self.dataset_has_state:
+            self._dataset_iterator.load_state_dict(state_dict['distributed_state_dict']['dataset'])
+
+    # TODO I don't like this
+    def run(self):
+        self._dataset_iterator = iter(self._dataset)
+        # self._dataset_iterator = self._dataset
+        # self._dataset_iterator = SequenceIterator(self._dataset)
+        # Sequential dataset
+        # if hasattr(self._dataset, '__getitem__'):
+        # # Streaming dataset
+        # elif hasattr(self._dataset, '__next__'):
+        #     self._dataset_iterator = self._dataset
+        # # Generator
+        # elif hasattr(self._dataset, '__iter__'):
+        #     self._dataset_iterator = iter(self._dataset)
+        # else:
+        #     raise TypeError(f'Dataset expected to be iterator, iterable or sequence, got {type(self._dataset)}')
+
+        return super().run()
+
+    @property
+    def declared_events(self):
+        return cb.BatchCallback.declared_events
+
+    def _create_context(self):
+        return BatchRunnerContext()
+
+    def _run_step(self):
+        context = self._create_context()
+        self._global_step += 1
+        self._callback.before_load(self, context)
+        context.batch = next(self._dataset_iterator)
+        self._callback.before_process_batch(self, context)
+        self._process_batch(context)
+        return context
+
+    def _process_batch(self, context):
+        pass
+
+
+__all__ = [
+    'BatchRunner',
+    'BatchRunnerContext',
+    'Runner',
+    'RunnerContext',
+]
diff --git a/src/irec/runners/inference.py b/src/irec/runners/inference.py
new file mode 100644
index 00000000..f73bb8ea
--- /dev/null
+++ b/src/irec/runners/inference.py
@@ -0,0 +1,45 @@
+import dataclasses
+from typing import Any
+
+import torch
+
+from irec.runners.base import BatchRunner, BatchRunnerContext
+
+
+@dataclasses.dataclass
+class InferenceRunnerContext(BatchRunnerContext):
+    model_outputs: Any = None
+
+
+class InferenceRunner(BatchRunner):
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        dataset,
+        callbacks
+    ):
+        super().__init__(
+            dataset=dataset,
+            callbacks=callbacks
+        )
+        self._model = model
model + + @property + def model(self): + return self._model + + def run(self): + with torch.inference_mode(mode=True): + training = self._model.training + try: + self._model.eval() + return super().run() + finally: + self._model.train(training) + + def _process_batch(self, context: InferenceRunnerContext): + _, context.model_outputs = self._model(context.batch) + + def _create_context(self): + return InferenceRunnerContext() + diff --git a/src/irec/runners/train.py b/src/irec/runners/train.py new file mode 100644 index 00000000..71b1d74d --- /dev/null +++ b/src/irec/runners/train.py @@ -0,0 +1,85 @@ +import dataclasses +from typing import Any + +import torch + +import irec.callbacks as cb + +from .base import BatchRunner, BatchRunnerContext + + +@dataclasses.dataclass +class TrainingRunnerContext(BatchRunnerContext): + model_outputs: Any = None + training_loss: Any = None + + +class TrainingRunner(BatchRunner): + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + dataset, + callbacks + ): + super().__init__( + dataset=dataset, + callbacks=callbacks + ) + self._model = model + self._optimizer = optimizer + + @property + def model(self): + return self._model + + @property + def optimizer(self): + return self._optimizer + + def state_dict(self): + state_dict = super().state_dict() + assert {'model', 'optimizer'}.isdisjoint(state_dict) + state_dict['model'] = self._model.state_dict() + state_dict['optimizer'] = self._optimizer.state_dict() + return state_dict + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self._model.load_state_dict(state_dict['model'], strict=True) + self._optimizer.load_state_dict(state_dict['optimizer']) + + def run(self): + assert self._model.training + self._optimizer.zero_grad() + return super().run() + + @property + def declared_events(self): + return cb.TrainingCallback.declared_events + + def _create_context(self): + return TrainingRunnerContext() + + def _process_batch(self, context: TrainingRunnerContext): + assert self._model.training + self._run_forward(context) + self._run_backward(context) + self._callback.before_optimizer(self, context) + self._run_optimizer(context) + + def _run_forward(self, context: TrainingRunnerContext): + context.training_loss, context.model_outputs = self._model(context.batch) + + def _run_backward(self, context: TrainingRunnerContext): + context.training_loss.backward() + + def _run_optimizer(self, context: TrainingRunnerContext): + self._optimizer.step() + self._optimizer.zero_grad() + + +__all__ = [ + 'TrainingRunner', + 'TraninngRunnerContext' +] diff --git a/src/irec/scheduler/base.py b/src/irec/scheduler/base.py deleted file mode 100644 index 334cef4b..00000000 --- a/src/irec/scheduler/base.py +++ /dev/null @@ -1,5 +0,0 @@ -from irec.utils import MetaParent - - -class BaseScheduler(metaclass=MetaParent): - pass diff --git a/src/irec/train.py b/src/irec/train.py deleted file mode 100644 index 78444396..00000000 --- a/src/irec/train.py +++ /dev/null @@ -1,200 +0,0 @@ -import irec.utils -from irec.utils import ( - parse_args, - create_logger, - DEVICE, - fix_random_seed, - ensure_checkpoints_dir, -) - -from irec.callbacks import BaseCallback -from irec.dataset import BaseDataset -from irec.dataloader import BaseDataloader -from irec.loss import BaseLoss -from irec.models import BaseModel -from irec.optimizer import BaseOptimizer - -import copy -import json -import os -import torch -import wandb - -logger = create_logger(name=__name__) -seed_val = 42 - - -def train( 
- dataloader, - model, - optimizer, - loss_function, - callback, - epoch_cnt=None, - step_cnt=None, - best_metric=None, -): - step_num = 0 - epoch_num = 0 - current_metric = 0 - - epochs_threshold = 40 - - best_epoch = 0 - best_checkpoint = None - - logger.debug('Start training...') - - while (epoch_cnt is None or epoch_num < epoch_cnt) and ( - step_cnt is None or step_num < step_cnt - ): - if best_epoch + epochs_threshold < epoch_num: - logger.debug( - 'There is no progress during {} epochs. Finish training'.format( - epochs_threshold, - ), - ) - break - - logger.debug(f'Start epoch {epoch_num}') - for step, batch in enumerate(dataloader): - batch_ = copy.deepcopy(batch) - - model.train() - - for key, values in batch_.items(): - batch_[key] = batch_[key].to(DEVICE) - - batch_.update(model(batch_)) - loss = loss_function(batch_) - - optimizer.step(loss) - callback(batch_, step_num) - step_num += 1 - - if best_metric is None: - # Take the last model - best_checkpoint = copy.deepcopy(model.state_dict()) - best_epoch = epoch_num - elif ( - best_checkpoint is None - or best_metric in batch_ - and current_metric <= batch_[best_metric] - ): - # If it is the first checkpoint, or it is the best checkpoint - current_metric = batch_[best_metric] - best_checkpoint = copy.deepcopy(model.state_dict()) - best_epoch = epoch_num - - epoch_num += 1 - logger.debug('Training procedure has been finished!') - return best_checkpoint - - -def main(): - fix_random_seed(seed_val) - config = parse_args() - - if config.get('use_wandb', False): - wandb.init( - project='irec', - name=config['experiment_name'], - sync_tensorboard=True, - ) - - tensorboard_writer = irec.utils.tensorboards.TensorboardWriter(config['experiment_name']) - irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER = tensorboard_writer - - log_dir = tensorboard_writer.log_dir - config_save_path = os.path.join(log_dir, 'config.json') - with open(config_save_path, 'w') as f: - json.dump(config, f, indent=2) - - logger.debug('Training config: \n{}'.format(json.dumps(config, indent=2))) - logger.debug('Current DEVICE: {}'.format(DEVICE)) - logger.info(f"Experiment config saved to: {config_save_path}") - - - dataset = BaseDataset.create_from_config(config['dataset']) - - train_sampler, validation_sampler, test_sampler = dataset.get_samplers() - - train_dataloader = BaseDataloader.create_from_config( - config['dataloader']['train'], - dataset=train_sampler, - **dataset.meta, - ) - - validation_dataloader = BaseDataloader.create_from_config( - config['dataloader']['validation'], - dataset=validation_sampler, - **dataset.meta, - ) - - eval_dataloader = BaseDataloader.create_from_config( - config['dataloader']['validation'], - dataset=test_sampler, - **dataset.meta, - ) - - model = BaseModel.create_from_config(config['model'], **dataset.meta).to( - DEVICE, - ) - if 'checkpoint' in config: - ensure_checkpoints_dir() - checkpoint_path = os.path.join( - './checkpoints', - f'{config["checkpoint"]}.pth', - ) - logger.debug('Loading checkpoint from {}'.format(checkpoint_path)) - checkpoint = torch.load(checkpoint_path) - logger.debug(checkpoint.keys()) - model.load_state_dict(checkpoint) - - loss_function = BaseLoss.create_from_config(config['loss']) - - optimizer = BaseOptimizer.create_from_config( - config['optimizer'], - model=model, - ) - - callback = BaseCallback.create_from_config( - config['callback'], - model=model, - train_dataloader=train_dataloader, - validation_dataloader=validation_dataloader, - eval_dataloader=eval_dataloader, - optimizer=optimizer, - 
**dataset.meta, - ) - - # TODO add verbose option for all callbacks, multiple optimizer options (???) - # TODO create pre/post callbacks - logger.debug('Everything is ready for training process!') - - # Train process - _ = train( - dataloader=train_dataloader, - model=model, - optimizer=optimizer, - loss_function=loss_function, - callback=callback, - epoch_cnt=config.get('train_epochs_num'), - step_cnt=config.get('train_steps_num'), - best_metric=config.get('best_metric'), - ) - - logger.debug('Saving model...') - ensure_checkpoints_dir() - checkpoint_path = './checkpoints/{}_final_state.pth'.format( - config['experiment_name'], - ) - torch.save(model.state_dict(), checkpoint_path) - logger.debug('Saved model as {}'.format(checkpoint_path)) - - if config.get('use_wandb', False): - wandb.finish() - - -if __name__ == '__main__': - main() diff --git a/src/irec/train_multiple.py b/src/irec/train_multiple.py deleted file mode 100644 index ede73a84..00000000 --- a/src/irec/train_multiple.py +++ /dev/null @@ -1,185 +0,0 @@ -import itertools -import json -import random -import torch - -import irec.utils -from irec.utils import ( - parse_args, - create_logger, - DEVICE, - Params, - dict_to_str, - fix_random_seed, - ensure_checkpoints_dir, -) - -from irec.train import train -from irec.infer import inference - -from irec.callbacks import BaseCallback, EvalCallback, ValidationCallback -from irec.dataset import BaseDataset -from irec.dataloader import BaseDataloader -from irec.loss import BaseLoss -from irec.models import BaseModel -from irec.optimizer import BaseOptimizer - -logger = create_logger(name=__name__) -seed_val = 42 - - -def main(): - fix_random_seed(seed_val) - config = parse_args() - - logger.debug('Training config: \n{}'.format(json.dumps(config, indent=2))) - - dataset_params = Params(config['dataset'], config['dataset_params']) - model_params = Params(config['model'], config['model_params']) - loss_function_params = Params(config['loss'], config['loss_params']) - optimizer_params = Params(config['optimizer'], config['optimizer_params']) - - logger.debug('Everything is ready for training process!') - - start_from = config.get('start_from', 0) - num = config.get('num_exps', None) - - list_of_params = list( - itertools.product( - dataset_params, - model_params, - loss_function_params, - optimizer_params, - ), - ) - - if num is None: - num = len(list_of_params) - else: - random.shuffle(list_of_params) - - cnt = 0 - - for ( - dataset_param, - model_param, - loss_param, - optimizer_param, - ) in list_of_params[start_from:num]: - cnt += 1 - if cnt < start_from: - continue - - model_name = '_'.join( - [ - config['experiment_name'], - dict_to_str(dataset_param, config['dataset_params']), - dict_to_str(model_param, config['model_params']), - dict_to_str(loss_param, config['loss_params']), - dict_to_str(optimizer_param, config['optimizer_params']), - ], - ) - - logger.debug('Starting {}'.format(model_name)) - - dataset = BaseDataset.create_from_config(dataset_param) - - train_sampler, validation_sampler, eval_sampler = ( - dataset.get_samplers() - ) - - train_dataloader = BaseDataloader.create_from_config( - config['dataloader']['train'], - dataset=train_sampler, - **dataset.meta, - ) - - validation_dataloader = BaseDataloader.create_from_config( - config['dataloader']['validation'], - dataset=validation_sampler, - **dataset.meta, - ) - - eval_dataloader = BaseDataloader.create_from_config( - config['dataloader']['validation'], - dataset=eval_sampler, - **dataset.meta, - ) - - if 
irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER is not None: - irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER.close() - irec.utils.tensorboards.GLOBAL_TENSORBOARD_WRITER = ( - irec.utils.tensorboards.TensorboardWriter( - model_name, - use_time=False, - ) - ) - - model = BaseModel.create_from_config(model_param, **dataset.meta).to( - DEVICE, - ) - loss_function = BaseLoss.create_from_config(loss_param) - optimizer = BaseOptimizer.create_from_config( - optimizer_param, - model=model, - ) - - callback = BaseCallback.create_from_config( - config['callback'], - model=model, - train_dataloader=train_dataloader, - validation_dataloader=validation_dataloader, - eval_dataloader=eval_dataloader, - optimizer=optimizer, - **dataset.meta, - ) - - best_model_checkpoint = train( - dataloader=train_dataloader, - model=model, - optimizer=optimizer, - loss_function=loss_function, - callback=callback, - epoch_cnt=config.get('train_epochs_num'), - best_metric=config.get('best_metric'), - ) - - eval_model = BaseModel.create_from_config( - model_param, - **dataset.meta, - ).to(DEVICE) - eval_model.load_state_dict(best_model_checkpoint) - - for cl in callback._callbacks: - if isinstance(cl, EvalCallback): - metrics = cl._metrics - pred_prefix = cl._pred_prefix - labels_prefix = cl._labels_prefix - break - else: - for cl in callback._callbacks: - if isinstance(cl, ValidationCallback): - metrics = cl._metrics - pred_prefix = cl._pred_prefix - labels_prefix = cl._labels_prefix - break - else: - assert False - - inference( - eval_dataloader, - eval_model, - metrics, - pred_prefix, - labels_prefix, - ) - - logger.debug('Saving best model checkpoint...') - ensure_checkpoints_dir() - checkpoint_path = './checkpoints/{}_final_state.pth'.format(model_name) - torch.save(best_model_checkpoint, checkpoint_path) - logger.debug('Saved model as {}'.format(checkpoint_path)) - - -if __name__ == '__main__': - main() diff --git a/src/irec/utils.py b/src/irec/utils.py new file mode 100644 index 00000000..87a6c7af --- /dev/null +++ b/src/irec/utils.py @@ -0,0 +1,16 @@ +import os +import random +import numpy as np +import torch + + +def fix_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + os.environ['PYTHONHASHSEED'] = str(seed) + \ No newline at end of file diff --git a/src/irec/utils/__init__.py b/src/irec/utils/__init__.py deleted file mode 100644 index 4010f140..00000000 --- a/src/irec/utils/__init__.py +++ /dev/null @@ -1,150 +0,0 @@ -from .registry import MetaParent -from .grid_search import Params -from .tensorboards import * - -__all__ = [ - 'MetaParent', - 'Params', -] - -import json -import random -import logging -import argparse -import numpy as np -import os - -import torch - -DEVICE = ( - torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') -) -# DEVICE = torch.device('cpu') - - -def parse_args(): - parser = argparse.ArgumentParser() - parser.add_argument('--params', required=True) - args = parser.parse_args() - - with open(args.params) as f: - params = json.load(f) - - return params - - -def create_logger( - name, - level=logging.DEBUG, - format='[%(asctime)s] [%(levelname)s]: %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', -): - logging.basicConfig(level=level, format=format, datefmt=datefmt) - logger = logging.getLogger(name) - return logger - - -def fix_random_seed(seed): - random.seed(seed) - 
np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def maybe_to_list(values): - if not isinstance(values, list): - values = [values] - return values - - -def get_activation_function(name, **kwargs): - if name == 'relu': - return torch.nn.ReLU() - elif name == 'gelu': - return torch.nn.GELU() - elif name == 'elu': - return torch.nn.ELU(alpha=float(kwargs.get('alpha', 1.0))) - elif name == 'leaky': - return torch.nn.LeakyReLU( - negative_slope=float(kwargs.get('negative_slope', 1e-2)), - ) - elif name == 'sigmoid': - return torch.nn.Sigmoid() - elif name == 'tanh': - return torch.nn.Tanh() - elif name == 'softmax': - return torch.nn.Softmax() - elif name == 'softplus': - return torch.nn.Softplus( - beta=int(kwargs.get('beta', 1.0)), - threshold=int(kwargs.get('threshold', 20)), - ) - elif name == 'softmax_logit': - return torch.nn.LogSoftmax() - else: - raise ValueError('Unknown activation function name `{}`'.format(name)) - - -def dict_to_str(x, params): - parts = [] - for k, v in x.items(): - if k in params: - if isinstance(v, dict): - # part = '_'.join([f'{k}-{sub_part}' for sub_part in dict_to_str(v, params[k]).split('_')]) - part = '_'.join( - [ - f'{sub_part}' - for sub_part in dict_to_str(v, params[k]).split('_') - ], - ) - elif isinstance(v, tuple) or isinstance(v, list): - sub_strings = [] - for i, sub_value in enumerate(v): - sub_strings.append( - f'({i})_{dict_to_str(v[i], params[k][i])}', - ) - part = f'({"_".join(sub_strings)})' - else: - # part = f'{k}-{v}' - part = f'{v}' - parts.append(part) - else: - continue - return '_'.join(parts).replace('.', '-') - - -def create_masked_tensor(data, lengths): - batch_size = lengths.shape[0] - max_sequence_length = lengths.max().item() - - if len(data.shape) == 1: # only indices - padded_tensor = torch.zeros( - batch_size, - max_sequence_length, - dtype=data.dtype, - device=DEVICE, - ) # (batch_size, max_seq_len) - else: - assert len(data.shape) == 2 # embeddings - padded_tensor = torch.zeros( - batch_size, - max_sequence_length, - data.shape[-1], - dtype=data.dtype, - device=DEVICE, - ) # (batch_size, max_seq_len, emb_dim) - - mask = ( - torch.arange(end=max_sequence_length, device=DEVICE)[None].tile( - [batch_size, 1], - ) - < lengths[:, None] - ) # (batch_size, max_seq_len) - - padded_tensor[mask] = data - - return padded_tensor, mask - - -def ensure_checkpoints_dir(): - os.makedirs('./checkpoints', exist_ok=True) diff --git a/src/irec/utils/grid_search.py b/src/irec/utils/grid_search.py deleted file mode 100644 index acdda7cf..00000000 --- a/src/irec/utils/grid_search.py +++ /dev/null @@ -1,60 +0,0 @@ -import copy -from itertools import product - - -class Params: - def __init__(self, config, params): - self._initial_config = copy.deepcopy(config) - self._initial_params = copy.deepcopy(params) - - def __iter__(self): - keys = [] - values = [] - - all_keys = set(self._initial_config.keys()).union( - set(self._initial_params.keys()), - ) - - for field_name in all_keys: - keys.append(field_name) - - initial_field_value = self._initial_config.get(field_name) - params_fields_value = self._initial_params.get(field_name) - - if initial_field_value: - if ( - params_fields_value is None - ): # We don't want to iterate through this field - values.append([initial_field_value]) - elif isinstance(initial_field_value, list) and isinstance( - initial_field_value, - list, - ): - assert len(initial_field_value) == len(params_fields_value) - list_values = [] - for i in range(len(initial_field_value)): - 
field_variations = list( - Params( - initial_field_value[i], - params_fields_value[i], - ), - ) - list_values.append(field_variations) - list_values = [p for p in product(*list_values)] - values.append(list_values) - elif isinstance( - initial_field_value, - dict, - ): # It is composite param, need to go inside - field_variations = list( - Params(initial_field_value, params_fields_value), - ) - values.append(field_variations) - else: # Simple param, can take as it is - values.append([initial_field_value] + params_fields_value) - else: - values.append(self._initial_params[field_name]) - - yield from [dict(zip(keys, p)) for p in product(*values)] - - return StopIteration diff --git a/src/irec/utils/registry.py b/src/irec/utils/registry.py deleted file mode 100644 index 5aeb85b3..00000000 --- a/src/irec/utils/registry.py +++ /dev/null @@ -1,79 +0,0 @@ -import inspect - - -class MetaParent(type): - def __init__(cls, name, base, params, **kwargs): - super().__init__(name, base, params) - is_base_class = cls.mro()[1] is object - if is_base_class: - base_class = cls - else: - base_class_found = False - for key in cls.mro(): - if isinstance(key, MetaParent) and key.mro()[1] is object: - assert ( - base_class_found is False - ), 'multiple base classes(bug)' - base_class = key - base_class_found = True - assert base_class_found is True, f'no base class for {name}' - - if is_base_class: - cls._subclasses = {} - - @classmethod - def __init_subclass__(scls, config_name=None): - super().__init_subclass__() - if config_name is not None: - if config_name in base_class._subclasses: - raise ValueError( - 'Class with name `{}` is already registered'.format( - config_name, - ), - ) - scls.config_name = config_name - base_class._subclasses[config_name] = scls - - cls.__init_subclass__ = __init_subclass__ - - @classmethod - def parent_create_from_config(cls, config, **kwargs): - if 'type' in config: - return cls._subclasses[config['type']].create_from_config( - config, - **kwargs, - ) - else: - raise ValueError( - 'There is no `type` provided for the `{}` class'.format( - name, - ), - ) - - # Take kwargs for the last initialized baseclass - init_kwargs = {} - for bcls in cls.mro()[:-1]: # Look into all base classes except object - if '__init__' not in bcls.__dict__: - continue - init_kwargs = inspect.signature(bcls.__init__).parameters - break - - @classmethod - def child_create_from_config(cls, config, **kwargs): - kwargs = {} - for key, argspec in init_kwargs.items(): - if key == 'self': - continue - value = config.get(key, argspec.default) - if value is inspect.Parameter.empty: - msg = 'There is no value for `{}.__init__` required field `{}` in config `{}`' - raise ValueError(msg.format(cls, key, config)) - kwargs[key] = value - return cls(**kwargs) - - if 'create_from_config' not in cls.__dict__: - cls.create_from_config = ( - parent_create_from_config - if is_base_class - else child_create_from_config - ) diff --git a/src/irec/utils/tensorboards/__init__.py b/src/irec/utils/tensorboards/__init__.py deleted file mode 100644 index d0a0aa75..00000000 --- a/src/irec/utils/tensorboards/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .tensorboard_writers import ( - TensorboardWriter, - GLOBAL_TENSORBOARD_WRITER, - LOGS_DIR, -) - -__all__ = [ - 'TensorboardWriter', - 'GLOBAL_TENSORBOARD_WRITER', - 'LOGS_DIR', -] diff --git a/src/irec/utils/tensorboards/tensorboard_writers.py b/src/irec/utils/tensorboards/tensorboard_writers.py deleted file mode 100644 index 82e78c2c..00000000 --- 
a/src/irec/utils/tensorboards/tensorboard_writers.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -import time -import datetime - -from torch.utils.tensorboard import SummaryWriter - -LOGS_DIR = './tensorboard_logs' -GLOBAL_TENSORBOARD_WRITER = None - - -class TensorboardWriter(SummaryWriter): - def __init__(self, experiment_name, use_time=True): - self._experiment_name = experiment_name - super().__init__( - log_dir=os.path.join( - LOGS_DIR, - f'{experiment_name}_{datetime.datetime.now().strftime("%Y-%m-%dT%H:%M" if use_time else "")}', - ), - ) - - def add_scalar(self, *args, **kwargs): - super().add_scalar(*args, **kwargs) - - -class TensorboardTimer: - def __init__(self, scope): - super().__init__(LOGS_DIR) - self._scope = scope - - def __enter__(self): - self.start = int(time.time() * 10000) - return self - - def __exit__(self, *args): - self.end = int(time.time() * 10000) - interval = (self.end - self.start) / 10.0 - GLOBAL_TENSORBOARD_WRITER.add_scalar(self._scope, interval) From 6279a2da20869f754ff8647f1d7a58daa83e0e4b Mon Sep 17 00:00:00 2001 From: Noname Untitled Date: Thu, 13 Nov 2025 00:08:31 +0300 Subject: [PATCH 2/5] Speed up tiger and sasrec models --- src/irec/callbacks/__init__.py | 3 + src/irec/callbacks/infer.py | 52 +++++++ src/irec/models/__init__.py | 3 +- src/irec/models/base.py | 27 +++- src/irec/models/flashattn.py | 233 +++++++++++++++++++++++++++++ src/irec/models/old_rqvae.py | 136 ----------------- src/irec/models/rqvae.py | 260 --------------------------------- 7 files changed, 316 insertions(+), 398 deletions(-) create mode 100644 src/irec/callbacks/infer.py create mode 100644 src/irec/models/flashattn.py delete mode 100644 src/irec/models/old_rqvae.py delete mode 100644 src/irec/models/rqvae.py diff --git a/src/irec/callbacks/__init__.py b/src/irec/callbacks/__init__.py index 1111adc7..cdc6e8f4 100644 --- a/src/irec/callbacks/__init__.py +++ b/src/irec/callbacks/__init__.py @@ -3,6 +3,7 @@ from irec.callbacks.base import Callback, BatchCallback, Composite +from irec.callbacks.infer import InferenceSaver from irec.callbacks.logging import Logger, LoggingCallback, TensorboardLogger from irec.callbacks.metrics import BatchMetrics, LambdaMetrics, MetricAccumulator, Accumulator, MeanAccumulator, Validation from irec.callbacks.model import LoadModel @@ -19,6 +20,8 @@ 'TrainingCallback', 'Composite', + 'InferenceSaver', + 'Logger', 'LoggingCallback', 'TensorboardLogger', diff --git a/src/irec/callbacks/infer.py b/src/irec/callbacks/infer.py new file mode 100644 index 00000000..937b001f --- /dev/null +++ b/src/irec/callbacks/infer.py @@ -0,0 +1,52 @@ +import json +import numpy as np +import torch +import pickle + +from irec.callbacks.train import TrainingCallback +from irec.runners.train import TrainingRunner, TrainingRunnerContext + + +class InferenceSaver(TrainingCallback): + def __init__(self, metrics, save_path, format='pickle'): + super().__init__() + self._metrics = metrics + self._save_path = save_path + self._format = format + self._accumulated_result = list() + assert format in ['pickle', 'json'], 'Unknown inference format!' 
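+
+    # Usage sketch (hypothetical values; `metrics` is any callable taking the
+    # (batch, model_outputs, metrics) triple that `after_step` passes in):
+    #
+    #     saver = InferenceSaver(
+    #         metrics=lambda batch, outputs, _: {'scores': outputs['scores']},
+    #         save_path='predictions.pkl',
+    #         format='pickle',
+    #     )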
+ + def state_dict(self): + return {'accumulate_result': self._accumulated_result} + + def load_state_dict(self, state_dict): + self._accumulated_result = state_dict['accumulate_result'] + + def before_run(self, runner: TrainingRunner): + return super().before_run(runner) + + def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext): + batch_result = self._metrics(context.batch, context.model_outputs, context.metrics) + processed_batch_result = {} + for key, values in batch_result.items(): + if isinstance(values, torch.Tensor): + processed_batch_result[key] = values.tolist() + elif isinstance(values, np.ndarray): + processed_batch_result[key] = values.tolist() + else: + assert isinstance(values, list) + processed_batch_result[key] = values + + self._accumulated_result.extend([ + {key: values[i] for key, values in processed_batch_result.items()} + for i in range(len(next(iter(processed_batch_result.values())))) + ]) + + def after_run(self, runner: TrainingRunner, context: TrainingRunnerContext): + if self._format == 'pickle': + with open(self._save_path, 'wb') as f: + pickle.dump(self._accumulated_result, f, protocol=pickle.HIGHEST_PROTOCOL) + + if self._format == 'json': + with open(self._save_path, 'w') as f: + json.dump(self._accumulated_result, f, indent=2) diff --git a/src/irec/models/__init__.py b/src/irec/models/__init__.py index 8c8013de..841af87c 100644 --- a/src/irec/models/__init__.py +++ b/src/irec/models/__init__.py @@ -1,7 +1,8 @@ -from irec.models.base import create_masked_tensor, TorchModel +from irec.models.base import create_masked_tensor, AutoCast, TorchModel __all__ = [ 'create_masked_tensor', + 'AutoCast', 'TorchModel', ] diff --git a/src/irec/models/base.py b/src/irec/models/base.py index 12699210..b97cade8 100644 --- a/src/irec/models/base.py +++ b/src/irec/models/base.py @@ -61,4 +61,29 @@ def _init_weights(self, initializer_range): b=2 * initializer_range, ) else: - raise ValueError(f'Unknown transformer weight: {key}') \ No newline at end of file + raise ValueError(f'Unknown transformer weight: {key}') + + +class AutoCast(nn.Module): + def __init__(self, module, dtype=torch.bfloat16, device_type='cuda', cache_enabled=True): + super().__init__() + self.module = module + self._dtype = dtype + self._device_type = device_type + self._cache_enabled = cache_enabled + + @property + def dtype(self): + return self._dtype + + @property + def cache_enabled(self): + return self._cache_enabled + + def forward(self, *args, **kwargs): + with torch.autocast( + device_type=self._device_type, + dtype=self._dtype, + cache_enabled=self._cache_enabled + ): + return self.module(*args, **kwargs) diff --git a/src/irec/models/flashattn.py b/src/irec/models/flashattn.py new file mode 100644 index 00000000..1870767b --- /dev/null +++ b/src/irec/models/flashattn.py @@ -0,0 +1,233 @@ +from typing import Any, Callable, Optional, Self, Union + +import einops +import torch +import torch.nn.functional as F + +from flash_attn.modules.mha import FlashSelfAttention + + +class MHAttention(torch.nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: Optional[int] = None, + dropout: float = 0.0, + window_size: tuple[int, int] = (-1, -1), + return_residual: bool = True, + causal: bool = False, + ): + super().__init__() + + self.embedding_dim = embedding_dim + self.dropout = dropout + self.window_size = window_size + self.return_residual = return_residual + self.causal = causal + + self.num_heads = num_heads or embedding_dim // 64 + assert self.embedding_dim % 
self.num_heads == 0
+        self.head_dim = self.embedding_dim // self.num_heads
+
+        self.self_attention = FlashSelfAttention(
+            causal=self.causal,
+            attention_dropout=self.dropout,
+            window_size=self.window_size,
+        )
+
+        self.Wqkv = torch.nn.Linear(self.embedding_dim, 3 * self.head_dim * self.num_heads)
+
+        self.out_proj = torch.nn.Linear(self.embedding_dim, self.embedding_dim)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        lengths: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        max_seqlen: int,
+    ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        qkv = self.Wqkv(x)
+
+        qkv = einops.rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
+
+        result = self.self_attention(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
+        result = self.out_proj(einops.rearrange(result, '... h d -> ... (h d)'))
+
+        return (result, x) if self.return_residual else result
+
+
+class MLP(torch.nn.Module):
+    # Note: 'swiglu' currently maps to plain SiLU, not a gated SwiGLU unit.
+    ACTIVATIONS = {'relu': F.relu, 'sigmoid': F.sigmoid, 'gelu': F.gelu, 'swiglu': F.silu}
+
+    def __init__(
+        self,
+        in_features: int,
+        dropout: float = 0.0,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        activation: Union[str, Callable[[torch.Tensor], torch.Tensor]] = F.relu,
+        bias1: bool = True,
+        bias2: bool = True,
+        return_residual: bool = False,
+    ):
+        super().__init__()
+
+        self.activation = MLP.ACTIVATIONS[activation] if isinstance(activation, str) else activation
+
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features * 4
+
+        self.return_residual = return_residual
+
+        self.fc1 = torch.nn.Linear(in_features, hidden_features, bias=bias1)
+        self.fc2 = torch.nn.Linear(hidden_features, out_features, bias=bias2)
+        self.dropout = torch.nn.Dropout(dropout)
+
+    def forward(self, x: torch.Tensor) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        y = self.fc1(x)
+        y = self.activation(y)
+        y = self.dropout(y)
+        y = self.fc2(y)
+        return (y, x) if self.return_residual else y
+
+
+class Block(torch.nn.Module):
+    NORMALIZATIONS = {'layer_norm': torch.nn.LayerNorm, 'rms_norm': torch.nn.RMSNorm}
+
+    def __init__(
+        self,
+        mixer: torch.nn.Module,
+        mlp: torch.nn.Module,
+        dropout1: torch.nn.Module,
+        norm1: torch.nn.Module,
+        norm2: torch.nn.Module,
+    ):
+        super().__init__()
+
+        self.mixer = mixer
+        self.mlp = mlp
+
+        self.dropout1 = dropout1
+        self.norm1 = norm1
+        self.norm2 = norm2
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        mixer_kwargs: Optional[dict[str, Any]] = None,
+    ) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
+        if mixer_kwargs is None:
+            mixer_kwargs = {}
+
+        mixed = self.mixer(hidden_states, **mixer_kwargs)
+        dropped = self.dropout1(mixed)
+        residual = hidden_states + dropped
+        hidden_states = self.norm1(residual)
+
+        mlped = self.mlp(hidden_states)
+        hidden_states = hidden_states + mlped
+        # Normalize the output of the MLP branch; normalizing `residual` here
+        # would silently drop the MLP contribution.
+        hidden_states = self.norm2(hidden_states)
+
+        return hidden_states
+
+    @staticmethod
+    def make_normalization(norm: str, **kwargs):
+        return Block.NORMALIZATIONS[norm](kwargs['embedding_dim'], eps=kwargs['eps'])
+
+    @classmethod
+    def make_default(
+        cls: type[Self],
+        embedding_dim: int,
+        dim_feedforward: int,
+        num_heads: Optional[int] = None,
+        dropout: float = 0.0,
+        norm: str = 'layer_norm',
+        activation: Union[str, Callable[[torch.Tensor], torch.Tensor]] = F.relu,
+        causal: bool = False,
+        eps: float = 1e-5,
+        attn_dropout: float = 0.0,
+        window_size: tuple[int, int] = (-1, -1),
+    ) -> Self:
+
+        norm1 = Block.make_normalization(norm, embedding_dim=embedding_dim,
eps=eps) + norm2 = Block.make_normalization(norm, embedding_dim=embedding_dim, eps=eps) + + mixer = MHAttention( + embedding_dim=embedding_dim, + dropout=attn_dropout, + return_residual=False, + causal=causal, + num_heads=num_heads, + window_size=window_size, + ) + + mlp = MLP(in_features=embedding_dim, dropout=dropout, hidden_features=dim_feedforward, return_residual=False, activation=activation) + + block = cls( + mixer=mixer, + mlp=mlp, + dropout1=torch.nn.Dropout(dropout), + norm1=norm1, + norm2=norm2, + ) + + return block + + +class TransformerEncoder(torch.nn.Module): + def __init__( + self, + embedding_dim: int, + dim_feedforward: int, + layers: Union[int, list[Block]], + num_heads: Optional[int] = None, + dropout: Union[float, tuple[float, float]] = 0.0, + norm: str = 'layer_norm', + activation: Union[str, Callable[[torch.Tensor], torch.Tensor]] = F.relu, + causal: bool = False, + attn_dropout: float = 0.0, + window_size: tuple[int, int] = (-1, -1), + ): + super().__init__() + + if isinstance(layers, int): + layers = [ + Block.make_default( + embedding_dim=embedding_dim, + dim_feedforward=dim_feedforward, + dropout=dropout, + norm=norm, + causal=causal, + activation=activation, + num_heads=num_heads, + attn_dropout=attn_dropout, + window_size=window_size, + ) + for _ in range(layers) + ] + + self._embedding_dim = embedding_dim + self.layers = torch.nn.ModuleList(layers) + + @property + def embedding_dim(self): + return self._embedding_dim + + def forward( + self, + embeddings: torch.Tensor, + lengths: torch.Tensor, + max_seqlen: Optional[int] = None, + **mixer_kwargs + ) -> torch.Tensor: + cu_seqlens = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0)) + + if max_seqlen is None: + max_seqlen = lengths.max().item() + + mixer_kwargs.update({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, 'lengths': lengths}) + + for layer in self.layers: + embeddings = layer(embeddings, mixer_kwargs=mixer_kwargs) + + return embeddings diff --git a/src/irec/models/old_rqvae.py b/src/irec/models/old_rqvae.py deleted file mode 100644 index 86c1d583..00000000 --- a/src/irec/models/old_rqvae.py +++ /dev/null @@ -1,136 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class ResidualBlock(nn.Module): - def __init__(self, input_dim, output_dim, dropout=0.1): - super().__init__() - self.input_dim = input_dim - self.output_dim = output_dim - - self.norm = nn.LayerNorm(input_dim) - self.layer = nn.Linear(input_dim, output_dim) - self.act = nn.GELU() - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - embedding = x - embedding = self.norm(embedding) - embedding = self.layer(embedding) - embedding = self.act(embedding) - embedding = self.dropout(embedding) - - if self.input_dim == self.output_dim: - return embedding + x - return embedding - - -class Tower(nn.Module): - def __init__(self, dims, dropout): - super().__init__() - self.layers = nn.ModuleList() - for i in range(len(dims) - 1): - self.layers.append(ResidualBlock(dims[i], dims[i + 1], dropout)) - - def forward(self, x): - embedding = x - for layer in self.layers: - embedding = layer(embedding) - return embedding - - -class RQVAE(nn.Module): - def __init__( - self, - input_dim, - num_codebooks, - codebook_size, - embedding_dim, - layers, - dropout_prob=0.0, - beta=0.25, - quant_loss_weight=1.0, - - ): - super().__init__() - - self.input_dim = input_dim - self.num_codebooks = num_codebooks - self.codebook_size = codebook_size - self.embedding_dim = embedding_dim - self.beta = beta - 
self.quant_loss_weight = quant_loss_weight - - self.layers = layers - self.dropout_prob = dropout_prob - - self.encoder_layer_dims = [self.input_dim] + self.layers + [self.embedding_dim] - self.decoder_layer_dims = self.encoder_layer_dims[::-1] - - # TODO add inizialisation with AE - self.encoder = Tower( - dims=self.encoder_layer_dims, - dropout=self.dropout_prob - ) - self.decoder = Tower( - dims=self.decoder_layer_dims, - dropout=self.dropout_prob - ) - - self.codebooks = torch.nn.ParameterList() - for _ in range(num_codebooks): - cb = torch.FloatTensor(codebook_size, embedding_dim) - self.codebooks.append(cb) - - @staticmethod - def make_encoding_tower(d1, d2, bias=False): - return torch.nn.Sequential( - nn.LayerNorm(d1), - nn.Linear(d1, d1), - nn.GELU(), - torch.nn.Linear(d1, d2, bias=bias) - ) - - @staticmethod - def get_codebook_indices(remainder, codebook): - dist = torch.cdist(remainder, codebook) - return dist.argmin(dim=-1) - - def forward(self, inputs): - latent_vector = self.encoder(inputs['embedding']) - - latent_restored = 0 - rqvae_loss = 0 - clusters = [] - remainder = latent_vector - for codebook in self.codebooks: - codebook_indices = self.get_codebook_indices(remainder, codebook) - clusters.append(codebook_indices) - - quantized = codebook[codebook_indices] - codebook_vectors = remainder + (quantized - remainder).detach() - - rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach()) - rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach()) - - latent_restored += codebook_vectors - remainder = remainder - codebook_vectors - - embeddings_restored = self.decoder(latent_restored) - recon_loss = F.mse_loss(embeddings_restored, inputs['embedding']) - loss = (recon_loss + self.quant_loss_weight * rqvae_loss).mean() - - clusters_counts = [] - for cluster in clusters: - clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size)) - - return loss, { - 'loss': loss.item(), - 'recon_loss': recon_loss.mean().item(), - 'rqvae_loss': rqvae_loss.mean().item(), - - 'clusters_counts': clusters_counts, - 'clusters': torch.stack(clusters).T, - 'embedding_hat': embeddings_restored, - } diff --git a/src/irec/models/rqvae.py b/src/irec/models/rqvae.py deleted file mode 100644 index 5cdfe784..00000000 --- a/src/irec/models/rqvae.py +++ /dev/null @@ -1,260 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class ResidualBlock(nn.Module): - def __init__(self, input_dim, output_dim, dropout=0.1): - super().__init__() - self.input_dim = input_dim - self.output_dim = output_dim - - self.norm = nn.LayerNorm(input_dim) - self.layer = nn.Linear(input_dim, output_dim) - self.act = nn.GELU() - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - embedding = x - embedding = self.norm(embedding) - embedding = self.layer(embedding) - embedding = self.act(embedding) - embedding = self.dropout(embedding) - - if self.input_dim == self.output_dim: - return embedding + x - return embedding - - -class Tower(nn.Module): - def __init__(self, dims, dropout): - super().__init__() - self.layers = nn.ModuleList() - for i in range(len(dims) - 1): - self.layers.append(ResidualBlock(dims[i], dims[i + 1], dropout)) - - def forward(self, x): - embedding = x - for layer in self.layers: - embedding = layer(embedding) - return embedding - - -class VectorQuantizer(nn.Module): - - def __init__( - self, - codebook_size, - embedding_dim, - mu=0.25, - ): - super().__init__() - self.codebook_size = codebook_size - 
self.embedding_dim = embedding_dim - self.mu = mu - - self.embedding = nn.Embedding(self.codebook_size, self.embedding_dim) - - def get_codebook(self): - return self.embedding.weight - - def forward(self, latent_embeddings): - # Get closest centroids - d = torch.sum(latent_embeddings**2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1, keepdim=True).t() - 2 * torch.matmul(latent_embeddings, self.embedding.weight.t()) - indices = torch.argmin(d, dim=-1) - - x_q = self.embedding(indices) - - # compute loss for embedding - commitment_loss = F.mse_loss(x_q.detach(), latent_embeddings) - codebook_loss = F.mse_loss(x_q, latent_embeddings.detach()) - - quantization_loss = codebook_loss + self.mu * commitment_loss - - # preserve gradients - x_q = latent_embeddings + (x_q - latent_embeddings).detach() - - indices = indices.view(latent_embeddings.shape[:-1]) - - return x_q, quantization_loss, indices - - -class ResidualVectorQuantizer(nn.Module): - def __init__( - self, - num_codebooks, - codebook_size, - embedding_dim, - ): - super().__init__() - self.num_codebooks = num_codebooks - self.codebook_size = codebook_size - self.embedding_dim = embedding_dim - - self.vq_layers: list[VectorQuantizer] = nn.ModuleList([ - VectorQuantizer(codebook_size, embedding_dim) for _ in range(num_codebooks) - ]) - - def forward(self, latent_embeddings): - all_losses = [] - all_indices = [] - - x_q = 0 - residual = latent_embeddings - - for quantizer in self.vq_layers: - x_res, loss, indices = quantizer(residual) - residual = residual - x_res - x_q = x_q + x_res - - all_losses.append(loss) - all_indices.append(indices) - - mean_losses = torch.stack(all_losses).mean() - all_indices = torch.stack(all_indices, dim=-1) - - return x_q, mean_losses, all_indices - - -class RQVAE(nn.Module): - def __init__( - self, - input_dim, - num_codebooks, - codebook_size, - embedding_dim, - layers, - dropout_prob=0.0, - beta=0.25, - quant_loss_weight=1.0, - cf_loss_weight=1.0, - cf_embeddings=None - ): - super().__init__() - - self.input_dim = input_dim - self.num_codebooks = num_codebooks - self.codebook_size = codebook_size - self.embedding_dim = embedding_dim - self.beta = beta - self.quant_loss_weight = quant_loss_weight - - self.layers = layers - self.dropout_prob = dropout_prob - self.cf_embeddings = cf_embeddings - self.cf_loss_weight = cf_loss_weight - - self.encoder_layer_dims = [self.input_dim] + self.layers + [self.embedding_dim] - self.decoder_layer_dims = self.encoder_layer_dims[::-1] - - # TODO add inizialisation with AE - self.encoder = Tower( - dims=self.encoder_layer_dims, - dropout=self.dropout_prob - ) - self.decoder = Tower( - dims=self.decoder_layer_dims, - dropout=self.dropout_prob - ) - - self.rq = ResidualVectorQuantizer( - num_codebooks=num_codebooks, - codebook_size=codebook_size, - embedding_dim=embedding_dim - ) - - @staticmethod - def get_codebook_indices(remainder, quantizer): - dist = torch.sum(remainder**2, dim=1, keepdim=True) + torch.sum(quantizer.embedding.weight**2, dim=1, keepdim=True).t() - 2 * torch.matmul(remainder, quantizer.embedding.weight.t()) - return dist.argmin(dim=-1) - - def forward(self, inputs): - latent_vector = self.encoder(inputs['embedding']) - - latent_restored = 0 - rqvae_loss = 0 - clusters = [] - remainder = latent_vector - for quantizer in self.rq.vq_layers: - codebook_indices = self.get_codebook_indices(remainder, quantizer) - clusters.append(codebook_indices) - - quantized = quantizer.embedding(codebook_indices) - codebook_vectors = remainder + (quantized - 
remainder).detach() - - rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach()) - rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach()) - - # codebook_vectors, quantizer_loss, codebook_indices = quantizer(remainder) - # rqvae_loss += quantizer_loss - - latent_restored += codebook_vectors - remainder = remainder - codebook_vectors - - embeddings_restored = self.decoder(latent_restored) - recon_loss = F.mse_loss(embeddings_restored, inputs['embedding']) - - # TODO for now - # if self.cf_embeddings is not None: - # cf_embedding_in_batch = self.cf_embeddings[item_ids] - # cf_embedding_in_batch = torch.from_numpy(cf_embedding_in_batch).to(quantized_embeddings.device) - # cf_loss = self.CF_loss(quantized_embeddings, cf_embedding_in_batch) - # else: - cf_loss = torch.as_tensor(0.0) - - loss = (recon_loss + self.quant_loss_weight * rqvae_loss + self.cf_loss_weight * cf_loss).mean() - - clusters_counts = [] - for cluster in clusters: - clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size)) - - # loss, recon_loss, cf_loss, rq_loss = self.compute_loss( - # content_embeddings=content_embeddings, - # out_embeddings=out_embeddings, - # item_ids=item_ids, - # rq_loss=rq_loss, - # quantized_embeddings=quantized_embeddings - # ) - - return loss, { - 'loss': loss.item(), - 'recon_loss': recon_loss.mean().item(), - 'rqvae_loss': rqvae_loss.mean().item(), - 'cf_loss': cf_loss.item(), - - 'clusters_counts': clusters_counts, - 'clusters': torch.stack(clusters).T, - 'embedding_hat': embeddings_restored, - } - - # def CF_loss(self, quantized_rep, encoded_rep): - # batch_size = quantized_rep.size(0) - # labels = torch.arange(batch_size, dtype=torch.long, device=quantized_rep.device) - # similarities = quantized_rep @ encoded_rep.T - # cf_loss = F.cross_entropy(similarities, labels) - # return cf_loss - - # @torch.no_grad() - # def get_indices(self, content_embeddings): - # latent_embeddings = self.encoder(content_embeddings) - # _, _, indices = self.rq(latent_embeddings) - # return indices - - # def compute_loss(self, content_embeddings, out_embeddings, item_ids, rq_loss, quantized_embeddings): - # if self.loss_type == 'mse': - # recon_loss = F.mse_loss(content_embeddings, out_embeddings, reduction='mean') - # elif self.loss_type == 'l1': - # recon_loss = F.l1_loss(content_embeddings, out_embeddings, reduction='mean') - # else: - # raise ValueError('incompatible loss type') - - # if self.cf_embeddings is not None: - # cf_embedding_in_batch = self.cf_embeddings[item_ids] - # cf_embedding_in_batch = torch.from_numpy(cf_embedding_in_batch).to(quantized_embeddings.device) - # cf_loss = self.CF_loss(quantized_embeddings, cf_embedding_in_batch) - # else: - # cf_loss = torch.as_tensor(0.0) - - # total_loss = recon_loss + self.quant_loss_weight * rq_loss + self.cf_loss_weight * cf_loss - - # return total_loss, recon_loss, cf_loss, rq_loss \ No newline at end of file From 4bf251a2104ceabcb0439607ead21c6ba5695b87 Mon Sep 17 00:00:00 2001 From: Noname Untitled Date: Thu, 13 Nov 2025 00:09:14 +0300 Subject: [PATCH 3/5] Add sasrec & tiger --- scripts/sasrec/data.py | 234 ++++++++++++++++++++++++++++++++++ scripts/sasrec/models.py | 169 ++++++++++++++++++++++++ scripts/sasrec/train.py | 204 +++++++++++++++++++++++++++++ scripts/sasrec/varka.py | 94 ++++++++++++++ scripts/tiger/data.py | 250 ++++++++++++++++++++++++++++++++++++ scripts/tiger/models.py | 181 ++++++++++++++++++++++++++ scripts/tiger/train.py | 205 ++++++++++++++++++++++++++++++ 
scripts/tiger/varka.py | 268 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 1605 insertions(+) create mode 100644 scripts/sasrec/data.py create mode 100644 scripts/sasrec/models.py create mode 100644 scripts/sasrec/train.py create mode 100644 scripts/sasrec/varka.py create mode 100644 scripts/tiger/data.py create mode 100644 scripts/tiger/models.py create mode 100644 scripts/tiger/train.py create mode 100644 scripts/tiger/varka.py diff --git a/scripts/sasrec/data.py b/scripts/sasrec/data.py new file mode 100644 index 00000000..9e07dcac --- /dev/null +++ b/scripts/sasrec/data.py @@ -0,0 +1,234 @@ +import json +from loguru import logger +import numpy as np +from pathlib import Path + +import pyarrow.feather as feather + +import torch + +from irec.data.base import BaseDataset + + +class Dataset: + def __init__( + self, + train_sampler, + validation_sampler, + test_sampler, + num_items, + max_sequence_length + ): + self._train_sampler = train_sampler + self._validation_sampler = validation_sampler + self._test_sampler = test_sampler + self._num_items = num_items + self._max_sequence_length = max_sequence_length + + @classmethod + def create(cls, inter_json_path, max_sequence_length, sampler_type, is_extended=False): + max_item_id = 0 + train_dataset, validation_dataset, test_dataset = [], [], [] + + with open(inter_json_path, 'r') as f: + user_interactions = json.load(f) + + for user_id_str, item_ids in user_interactions.items(): + user_id = int(user_id_str) + + if item_ids: + max_item_id = max(max_item_id, max(item_ids)) + + assert len(item_ids) >= 5, f'Core-5 dataset is used, user {user_id} has only {len(item_ids)} items' + + # sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] (leave one out scheme, 8 - train, 9 - valid, 10 - test) + if is_extended: + # sample = [1, 2] + # sample = [1, 2, 3] + # sample = [1, 2, 3, 4] + # sample = [1, 2, 3, 4, 5] + # sample = [1, 2, 3, 4, 5, 6] + # sample = [1, 2, 3, 4, 5, 6, 7] + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + for prefix_length in range(2, len(item_ids) - 2 + 1): + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids[:prefix_length], + }) + else: + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids[:-2], + }) + + # sample = [1, 2, 3, 4, 5, 6, 7, 8, 9] + validation_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids[:-1], + }) + + # sample = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + test_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids, + }) + + logger.debug(f'Train dataset size: {len(train_dataset)}') + logger.debug(f'Validation dataset size: {len(validation_dataset)}') + logger.debug(f'Test dataset size: {len(test_dataset)}') + logger.debug(f'Max item id: {max_item_id}') + + train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length) + validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length) + test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length) + + return cls( + train_sampler=train_sampler, + validation_sampler=validation_sampler, + test_sampler=test_sampler, + num_items=max_item_id + 1, # +1 added because our ids are 0-indexed + max_sequence_length=max_sequence_length + ) + + def get_datasets(self): + return self._train_sampler, self._validation_sampler, self._test_sampler + + @property + def num_items(self): + return self._num_items + + @property + def max_sequence_length(self): + return self._max_sequence_length + + +class 
TrainDataset(BaseDataset): + def __init__(self, dataset, prediction_type, max_sequence_length): + self._dataset = dataset + self._prediction_type = prediction_type + self._max_sequence_length = max_sequence_length + + self._transforms = { + 'sasrec': self._all_items_transform, + 'tiger': self._last_item_transform + } + + def _all_items_transform(self, sample): + item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1] + next_item_sequence = sample['item.ids'][-self._max_sequence_length:][1:] + return { + 'user.ids': np.array(sample['user.ids'], dtype=np.int64), + 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64), + 'item.ids': np.array(item_sequence, dtype=np.int64), + 'item.length': np.array([len(item_sequence)], dtype=np.int64), + 'labels.ids': np.array(next_item_sequence, dtype=np.int64), + 'labels.length': np.array([len(next_item_sequence)], dtype=np.int64) + } + + def _last_item_transform(self, sample): + item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1] + last_item = sample['item.ids'][-self._max_sequence_length:][-1] + return { + 'user.ids': np.array(sample['user.ids'], dtype=np.int64), + 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64), + 'item.ids': np.array(item_sequence, dtype=np.int64), + 'item.length': np.array([len(item_sequence)], dtype=np.int64), + 'labels.ids': np.array([last_item], dtype=np.int64), + 'labels.length': np.array([1], dtype=np.int64), + } + + def __getitem__(self, index): + return self._transforms[self._prediction_type](self._dataset[index]) + + def __len__(self): + return len(self._dataset) + + +class EvalDataset(BaseDataset): + def __init__(self, dataset, max_sequence_length): + self._dataset = dataset + self._max_sequence_length = max_sequence_length + + @property + def dataset(self): + return self._dataset + + def __len__(self): + return len(self._dataset) + + def __getitem__(self, index): + sample = self._dataset[index] + + item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1] + next_item = sample['item.ids'][-self._max_sequence_length:][-1] + + return { + 'user.ids': np.array(sample['user.ids'], dtype=np.int64), + 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64), + 'item.ids': np.array(item_sequence, dtype=np.int64), + 'item.length': np.array([len(item_sequence)], dtype=np.int64), + 'labels.ids': np.array([next_item], dtype=np.int64), + 'labels.length': np.array([1], dtype=np.int64), + 'visited.ids': np.array(sample['item.ids'][:-1], dtype=np.int64), + 'visited.length': np.array([len(sample['item.ids'][:-1])], dtype=np.int64), + } + + +class ArrowBatchDataset(BaseDataset): + def __init__(self, batch_dir, device='cuda', preload=False): + self.batch_dir = Path(batch_dir) + self.device = device + all_files = list(self.batch_dir.glob('batch_*_len_*.arrow')) + + from collections import defaultdict + batch_files_map = defaultdict(list) + + for f in all_files: + batch_id = int(f.stem.split('_')[1]) + batch_files_map[batch_id].append(f) + + for batch_id in batch_files_map: + batch_files_map[batch_id].sort() + + self.batch_indices = sorted(batch_files_map.keys()) + if preload: + print(f"Preloading {len(self.batch_indices)} batches...") + self.cached_batches = [] + + for idx in range(len(self.batch_indices)): + batch_id = self.batch_indices[idx] + arrow_files = batch_files_map[batch_id] + + batch = {} + for arrow_file in arrow_files: + table = feather.read_table(arrow_file) + + for column_name in table.column_names: + arr = table[column_name].to_numpy() + 
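+                        # The .copy() below appears to be needed because Arrow
+                        # can hand back a read-only, zero-copy view of the file
+                        # buffer, while torch.from_numpy wants a writable array.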
batch[column_name] = torch.from_numpy(arr.copy()).to(device) + + self.cached_batches.append(batch) + else: + self.cached_batches = None + self.batch_files_map = batch_files_map + + def __len__(self): + return len(self.batch_indices) + + def __getitem__(self, idx): + if self.cached_batches is not None: + return self.cached_batches[idx] + else: + batch_id = self.batch_indices[idx] + arrow_files = self.batch_files_map[batch_id] + + batch = {} + for arrow_file in arrow_files: + table = feather.read_table(arrow_file) + + for column_name in table.column_names: + arr = table[column_name].to_numpy() + batch[column_name] = torch.from_numpy(arr.copy()).to(self.device) + + return batch diff --git a/scripts/sasrec/models.py b/scripts/sasrec/models.py new file mode 100644 index 00000000..defa7f1d --- /dev/null +++ b/scripts/sasrec/models.py @@ -0,0 +1,169 @@ +import torch +import torch.nn as nn + +from irec.models import TorchModel, create_masked_tensor + +from irec.models.flashattn import TransformerEncoder + + +class SasRecModel(TorchModel): + def __init__( + self, + num_items, + max_sequence_length, + embedding_dim, + num_heads, + num_layers, + dim_feedforward, + activation, + topk_k, + dropout=0.0, + layer_norm_eps=1e-9, + initializer_range=0.02 + ): + super().__init__() + self._num_items = num_items + self._num_heads = num_heads + self._embedding_dim = embedding_dim + + self._item_embeddings = nn.Embedding( + num_embeddings=num_items, + embedding_dim=embedding_dim + ) + self._position_embeddings = nn.Embedding( + num_embeddings=max_sequence_length, + embedding_dim=embedding_dim + ) + + self._topk_k = topk_k + + self._encoder = TransformerEncoder( + embedding_dim=embedding_dim, + dim_feedforward=dim_feedforward, + layers=num_layers, + num_heads=num_heads, + dropout=dropout, + activation=activation, + causal=True, + ) + + self._init_weights(initializer_range) + + def forward(self, inputs): + all_sample_events = inputs['item.ids'] # (total_batch_items) + all_sample_lengths = inputs['item.length'] # (batch_size) + max_seqlen = int(all_sample_lengths.max().item()) + + embeddings = self._item_embeddings(all_sample_events) + + end_indices = all_sample_lengths.cumsum(dim=0) # (batch_size) + start_indices = end_indices - all_sample_lengths # (batch_size) + + sample_indices = torch.arange( + all_sample_lengths.shape[0], + device=all_sample_lengths.device + ).repeat_interleave(all_sample_lengths) # (total_batch_items) + + positions = torch.arange( + all_sample_events.shape[0], + device=all_sample_events.device + ) - start_indices[sample_indices] # (total_batch_items) + + position_embeddings = self._position_embeddings(positions) # (total_batch_items, embedding_dim) + + embeddings = embeddings + position_embeddings # (total_batch_items, embedding_dim) + + all_sample_embeddings = self._encoder(embeddings=embeddings, lengths=all_sample_lengths, max_seqlen=max_seqlen) # (total_batch_items, embedding_dim) + + all_positive_sample_events = inputs['labels.ids'] # (total_batch_items) + + if not self.training: + offsets = torch.cumsum(all_sample_lengths, dim=-1) + all_sample_embeddings = all_sample_embeddings[offsets - 1] + + all_embeddings = self._item_embeddings.weight # (num_items, embedding_dim) + + # a -- total_batch_items, n -- num_items, d -- embedding_dim + all_scores = torch.einsum( + 'ad,nd->an', + all_sample_embeddings, + all_embeddings + ) # (total_batch_items, num_items) + + positive_scores = torch.gather( + input=all_scores, + dim=1, + index=all_positive_sample_events[..., None] + )[:, 0] # 
(total_batch_items)
+
+        # Compute loss
+        negative_scores = torch.gather(
+            input=all_scores,
+            dim=1,
+            index=torch.randint(
+                low=0,
+                high=all_scores.shape[1],
+                size=all_positive_sample_events.shape,
+                device=all_positive_sample_events.device
+            )[..., None]
+        )[:, 0]  # (total_batch_items)
+
+        with torch.autocast(device_type='cuda', enabled=False):
+            loss = self._compute_loss(
+                positive_scores.float(),
+                negative_scores.float()
+            )
+
+        metrics = {
+            'loss': loss.detach()
+        }
+
+        if not self.training:
+            batch_size = all_sample_lengths.shape[0]
+            num_items = all_embeddings.shape[0]
+
+            padded_items, items_mask = create_masked_tensor(
+                data=all_sample_events,
+                lengths=all_sample_lengths,
+            )  # (batch_size, max_seq_len)
+
+            visited_mask = torch.zeros(
+                batch_size, num_items,
+                dtype=torch.bool,
+                device=all_sample_events.device
+            )
+
+            batch_indices = torch.arange(batch_size, device=all_sample_events.device)[:, None]
+            batch_indices = batch_indices.expand(-1, padded_items.shape[1])
+
+            # Mark only real (non-padding) positions as visited: a plain
+            # scatter_ over the zero-padded tensor would also flag item 0
+            # as visited for every user in the batch.
+            visited_mask[batch_indices[items_mask], padded_items[items_mask].long()] = True
+
+            all_scores = all_scores.masked_fill(visited_mask, float('-inf'))
+
+            positive_position = (all_scores > positive_scores[:, None]).float().sum(dim=-1)  # (batch_size)
+            dcg_score = 1. / (torch.log2(positive_position + 1) + 1.)
+
+            for k in [5, 10, 20]:
+                metrics[f'recall@{k}'] = (positive_position < k).float()
+                metrics[f'ndcg@{k}'] = torch.where(
+                    positive_position < k,
+                    dcg_score,
+                    torch.zeros_like(dcg_score)
+                ).float()
+
+        return loss, metrics
+
+    def _compute_loss(self, positive_scores, negative_scores):
+        assert positive_scores.shape[0] == negative_scores.shape[0]
+
+        loss = torch.nn.functional.binary_cross_entropy_with_logits(
+            positive_scores, torch.ones_like(positive_scores)
+        ) + torch.nn.functional.binary_cross_entropy_with_logits(
+            negative_scores, torch.zeros_like(negative_scores)
+        )
+
+        return loss
diff --git a/scripts/sasrec/train.py b/scripts/sasrec/train.py
new file mode 100644
index 00000000..21de6367
--- /dev/null
+++ b/scripts/sasrec/train.py
@@ -0,0 +1,204 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.transforms import Collate, ToDevice
+from irec.data.dataloader import DataLoader
+from irec.models import AutoCast
+from irec.runners import TrainingRunner
+from irec.utils import fix_random_seed
+
+from data import ArrowBatchDataset
+from models import SasRecModel
+
+SEED_VALUE = 42
+DEVICE = 'cuda'
+
+EXPERIMENT_NAME = 'sasrec_beauty'
+NUM_EPOCHS = 300
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 256
+EMBEDDING_DIM = 64
+NUM_HEADS = 2
+NUM_LAYERS = 2
+FEEDFORWARD_DIM = 256
+DROPOUT = 0.3
+LR = 1e-4
+
+NUM_ITEMS = 12101
+
+IREC_PATH = '../../'
+
+torch.set_float32_matmul_precision('high')
+torch._dynamo.config.capture_scalar_outputs = True
+
+
+def main():
+    fix_random_seed(SEED_VALUE)
+
+    train_dataloader = DataLoader(
+        ArrowBatchDataset(
+            os.path.join(IREC_PATH, 'data/Beauty/sasrec_train/'),
+            device='cpu',
+            preload=None
+        ),
+        batch_size=1,
+        shuffle=True,
+        num_workers=16,
+        prefetch_factor=16,
+        pin_memory=True,
+        persistent_workers=True,
+        collate_fn=Collate()
+    ).map(ToDevice(DEVICE))
+
+    valid_dataloder = ArrowBatchDataset(
+        os.path.join(IREC_PATH, 'data/Beauty/sasrec_valid/'),
+        device=DEVICE,
+        preload=True
+    )
+
+    eval_dataloder = ArrowBatchDataset(
+        os.path.join(IREC_PATH, 'data/Beauty/sasrec_eval/'),
+        device=DEVICE,
+        preload=True
+    )
+
+    model = SasRecModel(
+        num_items=NUM_ITEMS,
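+        # Beauty 5-core item vocabulary size: max item id + 1 (ids are 0-indexed).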
max_sequence_length=MAX_SEQ_LEN, + embedding_dim=EMBEDDING_DIM, + num_heads=NUM_HEADS, + num_layers=NUM_LAYERS, + dim_feedforward=FEEDFORWARD_DIM, + activation='relu', + topk_k=20, + dropout=DROPOUT, + layer_norm_eps=1e-8, + initializer_range=0.02 + ) + model = torch.compile(model, mode="default", fullgraph=False) + model = model.to('cuda') + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LR, + ) + + EPOCH_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + callbacks = [ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='train'), + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recall@5': cb.MeanAccumulator(), + 'train/recall@10': cb.MeanAccumulator(), + 'train/recall@20': cb.MeanAccumulator(), + 'train/ndcg@5': cb.MeanAccumulator(), + 'train/ndcg@10': cb.MeanAccumulator(), + 'train/ndcg@20': cb.MeanAccumulator(), + }, + reset_every_num_steps=EPOCH_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='validation'), + cb.MetricAccumulator( + accumulators={ + 'validation/loss': cb.MeanAccumulator(), + 'validation/recall@5': cb.MeanAccumulator(), + 'validation/recall@10': cb.MeanAccumulator(), + 'validation/recall@20': cb.MeanAccumulator(), + 'validation/ndcg@5': cb.MeanAccumulator(), + 'validation/ndcg@10': cb.MeanAccumulator(), + 'validation/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Validation( + dataset=eval_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + 'recall@5': model_outputs['recall@5'].tolist(), + 'recall@10': model_outputs['recall@10'].tolist(), + 'recall@20': model_outputs['recall@20'].tolist(), + 'ndcg@5': model_outputs['ndcg@5'].tolist(), + 'ndcg@10': model_outputs['ndcg@10'].tolist(), + 'ndcg@20': model_outputs['ndcg@20'].tolist(), + }, name='eval'), + cb.MetricAccumulator( + accumulators={ + 'eval/loss': cb.MeanAccumulator(), + 'eval/recall@5': cb.MeanAccumulator(), + 'eval/recall@10': cb.MeanAccumulator(), + 'eval/recall@20': cb.MeanAccumulator(), + 'eval/ndcg@5': cb.MeanAccumulator(), + 'eval/ndcg@10': cb.MeanAccumulator(), + 'eval/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Logger().every_num_steps(EPOCH_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.EarlyStopping( + metric='validation/ndcg@20', + patience=40, + minimize=False, + 
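+            # NDCG is a higher-is-better metric, hence minimize=False.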
model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME) + ).every_num_steps(EPOCH_NUM_STEPS) + + # cb.Profiler( + # wait=10, + # warmup=10, + # active=10, + # logdir=os.path.join(IREC_PATH, 'tensorboard_logs') + # ), + # cb.StopAfterNumSteps(40) + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=AutoCast(model, dtype=torch.bfloat16, device_type=DEVICE), + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/sasrec/varka.py b/scripts/sasrec/varka.py new file mode 100644 index 00000000..1b4b97cb --- /dev/null +++ b/scripts/sasrec/varka.py @@ -0,0 +1,94 @@ +from collections import defaultdict +import os +from pathlib import Path + +import pyarrow as pa +import pyarrow.feather as feather + +import torch + +from irec.data.transforms import Collate +from irec.data.dataloader import DataLoader + +from data import Dataset + + +NUM_EPOCHS = 300 +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SiZE = 256 + +IREC_PATH = '../../' + + +def save_batches_to_arrow(batches, output_dir): + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=False) + + for batch_idx, batch in enumerate(batches): + length_groups = defaultdict(dict) + + for key, value in batch.items(): + length = len(value) + length_groups[length][key] = value + + for length, fields in length_groups.items(): + arrow_dict = {k: pa.array(v) for k, v in fields.items()} + table = pa.table(arrow_dict) + + feather.write_feather( + table, + output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + compression='lz4' + ) + +def main(): + + data = Dataset.create( + inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter.json'), + max_sequence_length=MAX_SEQ_LEN, + sampler_type='sasrec', + is_extended=False + ) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ).map(Collate()).repeat(NUM_EPOCHS) + + valid_dataloder = DataLoader( + dataset=valid_dataset, + batch_size=VALID_BATCH_SiZE, + shuffle=False, + drop_last=False + ).map(Collate()) + + eval_dataloder = DataLoader( + dataset=eval_dataset, + batch_size=VALID_BATCH_SiZE, + shuffle=False, + drop_last=False + ).map(Collate()) + + train_batches = [] + for train_batch in train_dataloader: + train_batches.append(train_batch) + save_batches_to_arrow(train_batches, os.path.join(IREC_PATH, 'data/Beauty/sasrec_train/')) + + valid_batches = [] + for valid_batch in valid_dataloder: + valid_batches.append(valid_batch) + save_batches_to_arrow(valid_batches, os.path.join(IREC_PATH, 'data/Beauty/sasrec_valid/')) + + eval_batches = [] + for eval_batch in eval_dataloder: + eval_batches.append(eval_batch) + save_batches_to_arrow(eval_batches, os.path.join(IREC_PATH, 'data/Beauty/sasrec_eval/')) + + +if __name__ == '__main__': + main() diff --git a/scripts/tiger/data.py b/scripts/tiger/data.py new file mode 100644 index 00000000..188993a1 --- /dev/null +++ b/scripts/tiger/data.py @@ -0,0 +1,250 @@ +from collections import defaultdict +import json +from loguru import logger +import numpy as np +from pathlib import Path + + +import pyarrow as pa +import pyarrow.feather as feather + +import torch + +from irec.data.base import BaseDataset + + +class Dataset: + def __init__( + self, + train_sampler, + validation_sampler, + test_sampler, + num_items, + max_sequence_length + ): + 
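+        # Note: this module mirrors scripts/sasrec/data.py; the copy is kept
+        # so the tiger scripts stay runnable on their own.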
self._train_sampler = train_sampler + self._validation_sampler = validation_sampler + self._test_sampler = test_sampler + self._num_items = num_items + self._max_sequence_length = max_sequence_length + + @classmethod + def create(cls, inter_json_path, max_sequence_length, sampler_type, is_extended=False): + max_item_id = 0 + train_dataset, validation_dataset, test_dataset = [], [], [] + + with open(inter_json_path, 'r') as f: + user_interactions = json.load(f) + + for user_id_str, item_ids in user_interactions.items(): + user_id = int(user_id_str) + + if item_ids: + max_item_id = max(max_item_id, max(item_ids)) + + assert len(item_ids) >= 5, f'Core-5 dataset is used, user {user_id} has only {len(item_ids)} items' + + # sequence = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] (leave one out scheme, 8 - train, 9 - valid, 10 - test) + if is_extended: + # sample = [1, 2] + # sample = [1, 2, 3] + # sample = [1, 2, 3, 4] + # sample = [1, 2, 3, 4, 5] + # sample = [1, 2, 3, 4, 5, 6] + # sample = [1, 2, 3, 4, 5, 6, 7] + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + for prefix_length in range(2, len(item_ids) - 2 + 1): + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids[:prefix_length], + }) + else: + # sample = [1, 2, 3, 4, 5, 6, 7, 8] + train_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids[:-2], + }) + + # sample = [1, 2, 3, 4, 5, 6, 7, 8, 9] + validation_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids[:-1], + }) + + # sample = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + test_dataset.append({ + 'user.ids': [user_id], + 'item.ids': item_ids, + }) + + logger.debug(f'Train dataset size: {len(train_dataset)}') + logger.debug(f'Validation dataset size: {len(validation_dataset)}') + logger.debug(f'Test dataset size: {len(test_dataset)}') + logger.debug(f'Max item id: {max_item_id}') + + train_sampler = TrainDataset(train_dataset, sampler_type, max_sequence_length=max_sequence_length) + validation_sampler = EvalDataset(validation_dataset, max_sequence_length=max_sequence_length) + test_sampler = EvalDataset(test_dataset, max_sequence_length=max_sequence_length) + + return cls( + train_sampler=train_sampler, + validation_sampler=validation_sampler, + test_sampler=test_sampler, + num_items=max_item_id + 1, # +1 added because our ids are 0-indexed + max_sequence_length=max_sequence_length + ) + + def get_datasets(self): + return self._train_sampler, self._validation_sampler, self._test_sampler + + @property + def num_items(self): + return self._num_items + + @property + def max_sequence_length(self): + return self._max_sequence_length + + +class TrainDataset(BaseDataset): + def __init__(self, dataset, prediction_type, max_sequence_length): + self._dataset = dataset + self._prediction_type = prediction_type + self._max_sequence_length = max_sequence_length + + self._transforms = { + 'sasrec': self._all_items_transform, + 'tiger': self._last_item_transform + } + + def _all_items_transform(self, sample): + item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1] + next_item_sequence = sample['item.ids'][-self._max_sequence_length:][1:] + return { + 'user.ids': np.array(sample['user.ids'], dtype=np.int64), + 'user.length': np.array([len(sample['user.ids'])], dtype=np.int64), + 'item.ids': np.array(item_sequence, dtype=np.int64), + 'item.length': np.array([len(item_sequence)], dtype=np.int64), + 'labels.ids': np.array(next_item_sequence, dtype=np.int64), + 'labels.length': np.array([len(next_item_sequence)], dtype=np.int64) + } + + def _last_item_transform(self, sample): + 
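+        # TIGER-style target: only the last item of the truncated window is
+        # supervised, e.g. items [1..8] -> inputs [1..7], label [8].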
item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1]
+        last_item = sample['item.ids'][-self._max_sequence_length:][-1]
+        return {
+            'user.ids': np.array(sample['user.ids'], dtype=np.int64),
+            'user.length': np.array([len(sample['user.ids'])], dtype=np.int64),
+            'item.ids': np.array(item_sequence, dtype=np.int64),
+            'item.length': np.array([len(item_sequence)], dtype=np.int64),
+            'labels.ids': np.array([last_item], dtype=np.int64),
+            'labels.length': np.array([1], dtype=np.int64),
+        }
+
+    def __getitem__(self, index):
+        return self._transforms[self._prediction_type](self._dataset[index])
+
+    def __len__(self):
+        return len(self._dataset)
+
+
+class EvalDataset(BaseDataset):
+    def __init__(self, dataset, max_sequence_length):
+        self._dataset = dataset
+        self._max_sequence_length = max_sequence_length
+
+    @property
+    def dataset(self):
+        return self._dataset
+
+    def __len__(self):
+        return len(self._dataset)
+
+    def __getitem__(self, index):
+        sample = self._dataset[index]
+
+        item_sequence = sample['item.ids'][-self._max_sequence_length:][:-1]
+        next_item = sample['item.ids'][-self._max_sequence_length:][-1]
+
+        return {
+            'user.ids': np.array(sample['user.ids'], dtype=np.int64),
+            'user.length': np.array([len(sample['user.ids'])], dtype=np.int64),
+            'item.ids': np.array(item_sequence, dtype=np.int64),
+            'item.length': np.array([len(item_sequence)], dtype=np.int64),
+            'labels.ids': np.array([next_item], dtype=np.int64),
+            'labels.length': np.array([1], dtype=np.int64),
+            'visited.ids': np.array(sample['item.ids'][:-1], dtype=np.int64),
+            'visited.length': np.array([len(sample['item.ids'][:-1])], dtype=np.int64),
+        }
+
+
+class ArrowBatchDataset(BaseDataset):
+    def __init__(self, batch_dir, device='cuda', preload=False):
+        self.batch_dir = Path(batch_dir)
+        self.device = device
+
+        all_files = list(self.batch_dir.glob('batch_*_len_*.arrow'))
+
+        batch_files_map = defaultdict(list)
+        for f in all_files:
+            batch_id = int(f.stem.split('_')[1])
+            batch_files_map[batch_id].append(f)
+
+        for batch_id in batch_files_map:
+            batch_files_map[batch_id].sort()
+
+        self.batch_indices = sorted(batch_files_map.keys())
+
+        if preload:
+            logger.debug(f'Preloading {len(self.batch_indices)} batches...')
+            self.cached_batches = []
+
+            for idx in range(len(self.batch_indices)):
+                batch = self._load_batch(batch_files_map[self.batch_indices[idx]])
+                self.cached_batches.append(batch)
+        else:
+            self.cached_batches = None
+            self.batch_files_map = batch_files_map
+
+    def _load_batch(self, arrow_files):
+        batch = {}
+
+        for arrow_file in arrow_files:
+            table = feather.read_table(arrow_file)
+            metadata = table.schema.metadata or {}
+
+            for col_name in table.column_names:
+                col = table.column(col_name)
+
+                shape_key = f'{col_name}_shape'
+                dtype_key = f'{col_name}_dtype'
+
+                if shape_key.encode() in metadata:
+                    # ast.literal_eval is a safer way to parse the '(dim0, dim1)'
+                    # strings stored in the schema metadata than a bare eval
+                    shape = ast.literal_eval(metadata[shape_key.encode()].decode())
+                    dtype = np.dtype(metadata[dtype_key.encode()].decode())
+
+                    # check the Arrow column type before converting
+                    if pa.types.is_list(col.type) or pa.types.is_large_list(col.type):
+                        arr = np.array(col.to_pylist(), dtype=dtype)
+                    else:
+                        arr = col.to_numpy().reshape(shape).astype(dtype)
+                else:
+                    if pa.types.is_list(col.type) or pa.types.is_large_list(col.type):
+                        arr = np.array(col.to_pylist())
+                    else:
+                        arr = col.to_numpy()
+
+                batch[col_name] = torch.from_numpy(arr.copy()).to(self.device)
+
+        return batch
+
+    def __len__(self):
+        return len(self.batch_indices)
+
+    def __getitem__(self, idx):
+        if self.cached_batches is not None:
+            return self.cached_batches[idx]
+        else:
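+            # Lazy path. Batches were written by varka.py as one Feather table per
+            # (batch_idx, leading length) pair, so a logical batch may span several
+            # files whose columns share the same first dimension. A rough sketch of
+            # the expected on-disk layout (file names are illustrative only):
+            #
+            #   tiger_train_batches/
+            #     batch_000000_len_256.arrow   # per-row fields, length == batch size
+            #     batch_000000_len_1.arrow     # batch-level fields, if any
+            #
+            # _load_batch() merges these tables back into a single dict of tensors.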
+            batch_id = self.batch_indices[idx]
+            arrow_files = self.batch_files_map[batch_id]
+            return self._load_batch(arrow_files)
diff --git a/scripts/tiger/models.py b/scripts/tiger/models.py
new file mode 100644
index 00000000..4fa837ca
--- /dev/null
+++ b/scripts/tiger/models.py
@@ -0,0 +1,181 @@
+import torch
+from transformers import T5ForConditionalGeneration, T5Config, LogitsProcessor
+
+from irec.models import TorchModel
+
+
+class CorrectItemsLogitsProcessor(LogitsProcessor):
+    def __init__(self, num_codebooks, codebook_size, mapping, num_beams, visited_items):
+        self.num_codebooks = num_codebooks
+        self.codebook_size = codebook_size
+        self.num_beams = num_beams
+
+        semantic_ids = []
+        for i in range(len(mapping)):
+            assert len(mapping[str(i)]) == num_codebooks, 'All semantic ids must have the same length'
+            semantic_ids.append(mapping[str(i)])
+
+        self.index_semantic_ids = torch.tensor(semantic_ids, dtype=torch.long, device=visited_items.device)  # (num_items, semantic_ids)
+
+        batch_size, _ = visited_items.shape
+
+        self.index_semantic_ids = torch.tile(self.index_semantic_ids[None], dims=[batch_size, 1, 1])  # (batch_size, num_items, semantic_ids)
+
+        index = visited_items[..., None].tile(dims=[1, 1, num_codebooks])  # (batch_size, num_rated, semantic_ids)
+        self.index_semantic_ids = torch.scatter(
+            input=self.index_semantic_ids,
+            dim=1,
+            index=index,
+            src=torch.zeros_like(index)
+        )  # (batch_size, num_items, semantic_ids)
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        next_sid_codebook_num = (torch.minimum((input_ids[:, -1].max() // self.codebook_size), torch.as_tensor(self.num_codebooks - 1)).item() + 1) % self.num_codebooks
+        a = torch.tile(self.index_semantic_ids[:, None, :, next_sid_codebook_num], dims=[1, self.num_beams, 1])  # (batch_size, num_beams, num_items)
+        a = a.reshape(a.shape[0] * a.shape[1], a.shape[2])  # (batch_size * num_beams, num_items)
+
+        if next_sid_codebook_num != 0:
+            b = torch.tile(self.index_semantic_ids[:, None, :, :next_sid_codebook_num], dims=[1, self.num_beams, 1, 1])  # (batch_size, num_beams, num_items, sid_len)
+            b = b.reshape(b.shape[0] * b.shape[1], b.shape[2], b.shape[3])  # (batch_size * num_beams, num_items, sid_len)
+
+            current_prefixes = input_ids[:, -next_sid_codebook_num:]  # (batch_size * num_beams, sid_len)
+            possible_next_items_mask = (
+                torch.eq(current_prefixes[:, None, :], b).long().sum(dim=-1) == next_sid_codebook_num
+            )  # (batch_size * num_beams, num_items)
+            a[~possible_next_items_mask] = (next_sid_codebook_num + 1) * self.codebook_size
+
+        scores_mask = torch.zeros_like(scores).bool()  # (batch_size * num_beams, vocab_size)
+        scores_mask = torch.scatter_add(
+            input=scores_mask,
+            dim=-1,
+            index=a,
+            src=torch.ones_like(a).bool()
+        )
+
+        scores[:, :next_sid_codebook_num * self.codebook_size] = -torch.inf
+        scores[:, (next_sid_codebook_num + 1) * self.codebook_size:] = -torch.inf
+        scores[~(scores_mask.bool())] = -torch.inf
+
+        return scores
+
+
+class TigerModel(TorchModel):
+    def __init__(
+        self,
+        embedding_dim,
+        codebook_size,
+        sem_id_len,
+        num_positions,
+        user_ids_count,
+        num_heads,
+        num_encoder_layers,
+        num_decoder_layers,
+        dim_feedforward,
+        num_beams=100,
+        num_return_sequences=20,
+        d_kv=64,
+        layer_norm_eps=1e-6,
+        activation='relu',
+        dropout=0.1,
+        initializer_range=0.02,
+        logits_processor=None
+    ):
+        super().__init__()
+        self._embedding_dim = embedding_dim
+        self._codebook_size = codebook_size
+        self._num_positions = num_positions
+        self._num_heads = num_heads
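+        # Token-id layout of the unified vocabulary (derived from the construction
+        # of unified_vocab_size below; the concrete numbers are just a worked
+        # example for the defaults used in train.py):
+        #
+        #   [0, codebook_size * sem_id_len)   semantic-id tokens; codebook c occupies
+        #                                     [c * codebook_size, (c + 1) * codebook_size)
+        #   next user_ids_count ids           hashed user-id tokens
+        #   top 10 ids                        utility tokens (pad = V - 1,
+        #                                     eos = V - 2, decoder start = V - 3)
+        #
+        # e.g. codebook_size=256, sem_id_len=4, user_ids_count=2000 gives
+        # V = 1024 + 2000 + 10 = 3034, pad=3033, eos=3032, decoder_start=3031.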
+ self._num_encoder_layers = num_encoder_layers + self._num_decoder_layers = num_decoder_layers + self._dim_feedforward = dim_feedforward + self._num_beams = num_beams + self._num_return_sequences = num_return_sequences + self._d_kv = d_kv + self._layer_norm_eps = layer_norm_eps + self._activation = activation + self._dropout = dropout + self._sem_id_len = sem_id_len + self.user_ids_count = user_ids_count + self.logits_processor = logits_processor + + unified_vocab_size = codebook_size * self._sem_id_len + self.user_ids_count + 10 # 10 for utilities + self.config = T5Config( + vocab_size=unified_vocab_size, + d_model=self._embedding_dim, + d_kv=self._d_kv, + d_ff=self._dim_feedforward, + num_layers=self._num_encoder_layers, + num_decoder_layers=self._num_decoder_layers, + num_heads=self._num_heads, + dropout_rate=self._dropout, + is_encoder_decoder=True, + use_cache=False, + pad_token_id=unified_vocab_size - 1, + eos_token_id=unified_vocab_size - 2, + decoder_start_token_id=unified_vocab_size - 3, + layer_norm_epsilon=self._layer_norm_eps, + feed_forward_proj=self._activation, + tie_word_embeddings=False + ) + self.model = T5ForConditionalGeneration(config=self.config) + self._init_weights(initializer_range) + + self.model = torch.compile( + self.model, + mode='reduce-overhead', + fullgraph=False, + dynamic=True + ) + + def forward(self, inputs): + input_semantic_ids = inputs['input.data'] + attention_mask = inputs['input.mask'] + target_semantic_ids = inputs['output.data'] + + decoder_input_ids = target_semantic_ids[:, :-1].contiguous() + labels = target_semantic_ids[:, 1:].contiguous() + + model_output = self.model( + input_ids=input_semantic_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + labels=labels + ) + loss = model_output['loss'] + + metrics = {'loss': loss.detach()} + + if not self.training: + visited_batch = inputs['visited.padded'] + + output = self.model.generate( + input_ids=input_semantic_ids, + attention_mask=attention_mask, + num_beams=self._num_beams, + num_return_sequences=self._num_return_sequences, + max_length=self._sem_id_len + 1, + decoder_start_token_id=self.config.decoder_start_token_id, + eos_token_id=self.config.eos_token_id, + pad_token_id=self.config.pad_token_id, + do_sample=False, + early_stopping=False, + logits_processor=[self.logits_processor(visited_items=visited_batch)] if self.logits_processor is not None else [], + ) + + predictions = output[:, 1:].reshape(-1, self._num_return_sequences, self._sem_id_len) + + all_hits = (torch.eq(predictions, labels[:, None]).sum(dim=-1)) # (batch_size, top_k) + for k in [5, 10, 20]: + hits = (all_hits[:, :k] == self._sem_id_len).float() # (batch_size, k) + recall = hits.sum(dim=-1) # (batch_size) + discount_factor = 1 / torch.log2(torch.arange(1, k + 1, 1).float() + 1.).to(hits.device) # (k) + + metrics[f'recall@{k}'] = recall.cpu().float() + metrics[f'ndcg@{k}'] = torch.einsum('bk,k->b', hits, discount_factor).cpu().float() + + # for prefix_length in range(1, self._sem_id_len + 1): + # metrics[f'correct_prefix_{prefix_length}_recall@{self._num_return_sequences}'] = ( + # torch.eq(predictions[:, :k, :prefix_length], labels[:, None, :prefix_length]).long().sum(dim=-1) == prefix_length + # ).float().sum(dim=-1).mean(dim=-1).item() / self._num_return_sequences + + return loss, metrics diff --git a/scripts/tiger/train.py b/scripts/tiger/train.py new file mode 100644 index 00000000..3cd6fd99 --- /dev/null +++ b/scripts/tiger/train.py @@ -0,0 +1,205 @@ +from functools import partial +import 
json +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.transforms import Collate, ToDevice +from irec.data.dataloader import DataLoader +from irec.runners import TrainingRunner +from irec.utils import fix_random_seed + +from data import ArrowBatchDataset +from models import TigerModel, CorrectItemsLogitsProcessor + +SEED_VALUE = 42 +DEVICE = 'cuda' + +EXPERIMENT_NAME = 'tiger_beauty' +NUM_EPOCHS = 300 +MAX_SEQ_LEN = 20 +TRAIN_BATCH_SIZE = 256 +VALID_BATCH_SIZE = 1024 +EMBEDDING_DIM = 128 +CODEBOOK_SIZE = 256 +NUM_POSITIONS = 20 +NUM_USER_HASH = 2000 +NUM_HEADS = 6 +NUM_LAYERS = 4 +FEEDFORWARD_DIM = 1024 +KV_DIM = 64 +DROPOUT = 0.1 +NUM_BEAMS = 30 +TOP_K = 20 +NUM_CODEBOOKS = 4 +LR = 3e-4 + +IREC_PATH = '../../' + +torch.set_float32_matmul_precision('high') +torch._dynamo.config.capture_scalar_outputs = True + +import torch._inductor.config as config +config.triton.cudagraph_skip_dynamic_graphs = True + + +def main(): + fix_random_seed(SEED_VALUE) + + with open(os.path.join(IREC_PATH, 'results/rqvae_beauty_best_clusters_colisionless.json'), 'r') as f: + mappings = json.load(f) + + train_dataloader = DataLoader( + ArrowBatchDataset( + os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/'), + device='cpu', + preload=True + ), + batch_size=1, + shuffle=True, + num_workers=0, + pin_memory=True, + collate_fn=Collate() + ).map(ToDevice(DEVICE)).repeat(NUM_EPOCHS) + + valid_dataloder = ArrowBatchDataset( + os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/'), + device=DEVICE, + preload=True + ) + + eval_dataloder = ArrowBatchDataset( + os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/'), + device=DEVICE, + preload=True + ) + + model = TigerModel( + embedding_dim=EMBEDDING_DIM, + codebook_size=CODEBOOK_SIZE, + sem_id_len=NUM_CODEBOOKS, + user_ids_count=NUM_USER_HASH, + num_positions=NUM_POSITIONS, + num_heads=NUM_HEADS, + num_encoder_layers=NUM_LAYERS, + num_decoder_layers=NUM_LAYERS, + dim_feedforward=FEEDFORWARD_DIM, + num_beams=NUM_BEAMS, + num_return_sequences=TOP_K, + activation='relu', + d_kv=KV_DIM, + dropout=DROPOUT, + layer_norm_eps=1e-6, + initializer_range=0.02, + logits_processor=partial( + CorrectItemsLogitsProcessor, + NUM_CODEBOOKS, + CODEBOOK_SIZE, + mappings, + NUM_BEAMS + ) + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.AdamW( + model.parameters(), + lr=LR, + ) + + EPOCH_NUM_STEPS = 1024 # int(len(train_dataloader) // NUM_EPOCHS) + + callbacks = [ + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'].item(), + }, name='train'), + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + }, + reset_every_num_steps=EPOCH_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloder, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, _: model_outputs, name='validation'), + cb.MetricAccumulator( + accumulators={ + 'validation/loss': cb.MeanAccumulator(), + 'validation/recall@5': cb.MeanAccumulator(), + 'validation/recall@10': cb.MeanAccumulator(), + 'validation/recall@20': cb.MeanAccumulator(), + 'validation/ndcg@5': cb.MeanAccumulator(), + 'validation/ndcg@10': cb.MeanAccumulator(), + 'validation/ndcg@20': cb.MeanAccumulator(), + }, + ), + ], + ).every_num_steps(EPOCH_NUM_STEPS), + + cb.Validation( 
+            dataset=eval_dataloder,
+            callbacks=[
+                cb.BatchMetrics(metrics=lambda model_outputs, _: {
+                    'loss': model_outputs['loss'].item(),
+                    'recall@5': model_outputs['recall@5'].tolist(),
+                    'recall@10': model_outputs['recall@10'].tolist(),
+                    'recall@20': model_outputs['recall@20'].tolist(),
+                    'ndcg@5': model_outputs['ndcg@5'].tolist(),
+                    'ndcg@10': model_outputs['ndcg@10'].tolist(),
+                    'ndcg@20': model_outputs['ndcg@20'].tolist(),
+                }, name='eval'),
+                cb.MetricAccumulator(
+                    accumulators={
+                        'eval/loss': cb.MeanAccumulator(),
+                        'eval/recall@5': cb.MeanAccumulator(),
+                        'eval/recall@10': cb.MeanAccumulator(),
+                        'eval/recall@20': cb.MeanAccumulator(),
+                        'eval/ndcg@5': cb.MeanAccumulator(),
+                        'eval/ndcg@10': cb.MeanAccumulator(),
+                        'eval/ndcg@20': cb.MeanAccumulator(),
+                    },
+                ),
+            ],
+        ).every_num_steps(EPOCH_NUM_STEPS),
+
+        cb.Logger().every_num_steps(EPOCH_NUM_STEPS),
+        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+
+        cb.EarlyStopping(
+            metric='eval/ndcg@20',
+            patience=40,
+            minimize=False,
+            model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
+        ).every_num_steps(EPOCH_NUM_STEPS)
+
+        # cb.Profiler(
+        #     wait=10,
+        #     warmup=10,
+        #     active=10,
+        #     logdir=os.path.join(IREC_PATH, 'tensorboard_logs')
+        # ),
+        # cb.StopAfterNumSteps(40)
+    ]
+
+    logger.debug('Everything is ready for training process!')
+
+    runner = TrainingRunner(
+        model=model,
+        optimizer=optimizer,
+        dataset=train_dataloader,
+        callbacks=callbacks,
+    )
+    runner.run()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/tiger/varka.py b/scripts/tiger/varka.py
new file mode 100644
index 00000000..ed475953
--- /dev/null
+++ b/scripts/tiger/varka.py
@@ -0,0 +1,268 @@
+from collections import defaultdict
+import json
+import murmurhash
+import numpy as np
+import os
+from pathlib import Path
+
+import pyarrow as pa
+import pyarrow.feather as feather
+
+import torch
+
+from irec.data.transforms import Collate, Transform
+from irec.data.dataloader import DataLoader
+
+from data import Dataset
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+
+MAX_SEQ_LEN = 20
+TRAIN_BATCH_SIZE = 256
+VALID_BATCH_SIZE = 1024
+NUM_USER_HASH = 2000
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 4
+
+UNIFIED_VOCAB_SIZE = CODEBOOK_SIZE * NUM_CODEBOOKS + NUM_USER_HASH + 10  # 10 for utilities
+# no trailing commas here: with them these constants become 1-tuples and the
+# token assignments below would write tuples into the id arrays
+PAD_TOKEN_ID = UNIFIED_VOCAB_SIZE - 1
+EOS_TOKEN_ID = UNIFIED_VOCAB_SIZE - 2
+DECODER_START_TOKEN_ID = UNIFIED_VOCAB_SIZE - 3
+
+
+IREC_PATH = '../../'
+
+
+class TigerProcessing(Transform):
+    def __call__(self, batch):
+        input_semantic_ids, attention_mask = batch['item.semantic.padded'], batch['item.semantic.mask']
+        batch_size = attention_mask.shape[0]
+
+        # positions masked out by ToMasked hold zeros, which collide with real
+        # codebook-0 tokens; overwrite them with the dedicated PAD token id
+        input_semantic_ids[~attention_mask] = PAD_TOKEN_ID
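+        # Shapes at this point, as a rough worked example (batch of 2 users,
+        # NUM_CODEBOOKS=4, CODEBOOK_SIZE=256; numbers illustrative only):
+        #
+        #   input_semantic_ids: (2, seq_len * 4) semantic tokens in [0, 1024) or PAD
+        #   attention_mask:     (2, seq_len * 4) booleans
+        #
+        # the two concatenations below append one hashed-user token per row
+        # (offset by NUM_CODEBOOKS * CODEBOOK_SIZE = 1024 into the user-id range)
+        # and extend the mask accordingly.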
+
+        input_semantic_ids = np.concatenate([
+            input_semantic_ids,
+            NUM_CODEBOOKS * CODEBOOK_SIZE + batch['user.hashed.ids'][:, None]
+        ], axis=-1)
+
+        attention_mask = np.concatenate([
+            attention_mask,
+            np.ones((batch_size, 1), dtype=attention_mask.dtype)
+        ], axis=-1)
+
+        batch['input.data'] = input_semantic_ids
+        batch['input.mask'] = attention_mask
+
+        target_semantic_ids = batch['labels.semantic.padded']
+        target_semantic_ids = np.concatenate([
+            np.ones(
+                (batch_size, 1),
+                dtype=np.int64,
+            ) * DECODER_START_TOKEN_ID,
+            target_semantic_ids
+        ], axis=-1)
+
+        batch['output.data'] = target_semantic_ids
+
+        return batch
+
+
+class ToMasked(Transform):
+    def __init__(self, prefix, is_right_aligned=False):
+        self._prefix = prefix
+        self._is_right_aligned = is_right_aligned
+
+    def __call__(self, batch):
+        data = batch[f'{self._prefix}.ids']
+        lengths = batch[f'{self._prefix}.length']
+
+        batch_size = lengths.shape[0]
+        max_sequence_length = int(lengths.max())
+
+        if len(data.shape) == 1:  # only indices
+            padded_tensor = np.zeros(
+                (batch_size, max_sequence_length),
+                dtype=data.dtype
+            )  # (batch_size, max_seq_len)
+        else:
+            assert len(data.shape) == 2  # embeddings
+            padded_tensor = np.zeros(
+                (batch_size, max_sequence_length, data.shape[-1]),
+                dtype=data.dtype
+            )  # (batch_size, max_seq_len, emb_dim)
+
+        mask = np.arange(max_sequence_length)[None] < lengths[:, None]
+
+        if self._is_right_aligned:
+            mask = np.flip(mask, axis=-1)
+
+        padded_tensor[mask] = data
+
+        batch[f'{self._prefix}.padded'] = padded_tensor
+        batch[f'{self._prefix}.mask'] = mask
+
+        return batch
+
+
+class SemanticIdsMapper(Transform):
+    def __init__(self, mapping, names=()):
+        super().__init__()
+        self._mapping = mapping
+        self._names = names
+
+        data = []
+        for i in range(len(mapping)):
+            data.append(mapping[str(i)])
+        self._mapping_tensor = torch.tensor(data, dtype=torch.long)
+        self._semantic_length = self._mapping_tensor.shape[-1]
+
+    def __call__(self, batch):
+        for name in self._names:
+            if f'{name}.ids' in batch:
+                ids = batch[f'{name}.ids']
+                lengths = batch[f'{name}.length']
+                assert ids.min() >= 0
+                assert ids.max() < self._mapping_tensor.shape[0]
+                batch[f'{name}.semantic.ids'] = self._mapping_tensor[ids].flatten().numpy()
+                batch[f'{name}.semantic.length'] = lengths * self._semantic_length
+
+        return batch
+
+
+class UserHashing(Transform):
+    def __init__(self, hash_size):
+        super().__init__()
+        self._hash_size = hash_size
+
+    def __call__(self, batch):
+        batch['user.hashed.ids'] = np.array([murmurhash.hash(str(x)) % self._hash_size for x in batch['user.ids']], dtype=np.int64)
+        return batch
+
+
+def save_batches_to_arrow(batches, output_dir):
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=False)
+
+    for batch_idx, batch in enumerate(batches):
+        length_groups = defaultdict(dict)
+        metadata_groups = defaultdict(dict)
+
+        for key, value in batch.items():
+            length = len(value)
+
+            metadata_groups[length][f'{key}_shape'] = str(value.shape)
+            metadata_groups[length][f'{key}_dtype'] = str(value.dtype)
+
+            if value.ndim == 1:
+                # 1D array - store as is
+                length_groups[length][key] = value
+            elif value.ndim == 2:
+                # 2D array - store as a list of lists
+                length_groups[length][key] = value.tolist()
+            else:
+                # >2D array - flatten; the shape is kept in the schema metadata
+                length_groups[length][key] = value.flatten()
+
+        for length, fields in length_groups.items():
+            arrow_dict = {}
+            for k, v in fields.items():
+                if isinstance(v, list) and len(v) > 0 and isinstance(v[0], list):
+                    # list of lists (2D)
+
arrow_dict[k] = pa.array(v) + else: + arrow_dict[k] = pa.array(v) + + table = pa.table(arrow_dict) + if length in metadata_groups: + table = table.replace_schema_metadata(metadata_groups[length]) + + feather.write_feather( + table, + output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + compression='lz4' + ) + + # arrow_dict = {k: pa.array(v) for k, v in fields.items()} + # table = pa.table(arrow_dict) + + # feather.write_feather( + # table, + # output_dir / f"batch_{batch_idx:06d}_len_{length}.arrow", + # compression='lz4' + # ) + + +def main(): + data = Dataset.create( + inter_json_path=os.path.join(IREC_PATH, 'data/Beauty/inter.json'), + max_sequence_length=MAX_SEQ_LEN, + sampler_type='tiger', + is_extended=True + ) + + with open(os.path.join(IREC_PATH, 'results/rqvae_beauty_best_clusters_colisionless.json'), 'r') as f: + mappings = json.load(f) + + train_dataset, valid_dataset, eval_dataset = data.get_datasets() + + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=TRAIN_BATCH_SIZE, + shuffle=True, + drop_last=True + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(TigerProcessing()) + + valid_dataloader = DataLoader( + dataset=valid_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + eval_dataloader = DataLoader( + dataset=eval_dataset, + batch_size=VALID_BATCH_SIZE, + shuffle=False, + drop_last=False + ) \ + .map(Collate()) \ + .map(UserHashing(NUM_USER_HASH)) \ + .map(SemanticIdsMapper(mappings, names=['item', 'labels'])) \ + .map(ToMasked('item.semantic', is_right_aligned=True)) \ + .map(ToMasked('labels.semantic', is_right_aligned=True)) \ + .map(ToMasked('visited', is_right_aligned=True)) \ + .map(TigerProcessing()) + + train_batches = [] + for train_batch in train_dataloader: + train_batches.append(train_batch) + save_batches_to_arrow(train_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_train_batches/')) + + valid_batches = [] + for valid_batch in valid_dataloader: + valid_batches.append(valid_batch) + save_batches_to_arrow(valid_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_valid_batches/')) + + eval_batches = [] + for eval_batch in eval_dataloader: + eval_batches.append(eval_batch) + save_batches_to_arrow(eval_batches, os.path.join(IREC_PATH, 'data/Beauty/tiger_eval_batches/')) + + +if __name__ == '__main__': + main() From 081ad1cbd7de977e8de4bf36174009980458cecc Mon Sep 17 00:00:00 2001 From: Noname Untitled Date: Mon, 17 Nov 2025 15:43:07 +0300 Subject: [PATCH 4/5] Add rqvae --- scripts/rqvae/callbacks.py | 64 ++++ scripts/rqvae/data.py | 57 ++++ scripts/rqvae/infer_best.py | 132 +++++++ scripts/rqvae/infer_new.py | 100 ++++++ scripts/rqvae/models.py | 548 ++++++++++++++++++++++++++++++ scripts/rqvae/train_best.py | 149 ++++++++ scripts/rqvae/train_best_pairs.py | 156 +++++++++ scripts/rqvae/train_new.py | 218 ++++++++++++ scripts/rqvae/train_old.py | 149 ++++++++ 9 files changed, 1573 insertions(+) create mode 100644 scripts/rqvae/callbacks.py create mode 100644 scripts/rqvae/data.py create 
mode 100644 scripts/rqvae/infer_best.py create mode 100644 scripts/rqvae/infer_new.py create mode 100644 scripts/rqvae/models.py create mode 100644 scripts/rqvae/train_best.py create mode 100644 scripts/rqvae/train_best_pairs.py create mode 100644 scripts/rqvae/train_new.py create mode 100644 scripts/rqvae/train_old.py diff --git a/scripts/rqvae/callbacks.py b/scripts/rqvae/callbacks.py new file mode 100644 index 00000000..43ec460a --- /dev/null +++ b/scripts/rqvae/callbacks.py @@ -0,0 +1,64 @@ +import torch + +import irec.callbacks as cb +from irec.runners import TrainingRunner, TrainingRunnerContext + +class InitCodebooks(cb.TrainingCallback): + def __init__(self, dataloader): + super().__init__() + self._dataloader = dataloader + + @torch.no_grad() + def before_run(self, runner: TrainingRunner): + for i in range(len(runner.model.codebooks)): + X = next(iter(self._dataloader))['embedding'] + idx = torch.randperm(X.shape[0], device=X.device)[:len(runner.model.codebooks[i])] + remainder = runner.model.encoder(X[idx]) + + for j in range(i): + codebook_indices = runner.model.get_codebook_indices(remainder, runner.model.codebooks[j]) + codebook_vectors = runner.model.codebooks[j][codebook_indices] + remainder = remainder - codebook_vectors + + runner.model.codebooks[i].data = remainder.detach() + + +class FixDeadCentroids(cb.TrainingCallback): + def __init__(self, dataloader): + super().__init__() + self._dataloader = dataloader + + def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext): + for i, num_fixed in enumerate(self.fix_dead_codebooks(runner)): + context.metrics[f'num_dead/{i}'] = num_fixed + + @torch.no_grad() + def fix_dead_codebooks(self, runner: TrainingRunner): + num_fixed = [] + for codebook_idx, codebook in enumerate(runner.model.codebooks): + centroid_counts = torch.zeros(codebook.shape[0], dtype=torch.long, device=codebook.device) + random_batch = next(iter(self._dataloader))['embedding'] + + for batch in self._dataloader: + remainder = runner.model.encoder(batch['embedding']) + for l in range(codebook_idx): + ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l]) + remainder = remainder - runner.model.codebooks[l][ind] + + indices = runner.model.get_codebook_indices(remainder, codebook) + centroid_counts.scatter_add_(0, indices, torch.ones_like(indices)) + + dead_mask = (centroid_counts == 0) + num_dead = int(dead_mask.sum().item()) + num_fixed.append(num_dead) + if num_dead == 0: + continue + + remainder = runner.model.encoder(random_batch) + for l in range(codebook_idx): + ind = runner.model.get_codebook_indices(remainder, runner.model.codebooks[l]) + remainder = remainder - runner.model.codebooks[l][ind] + remainder = remainder[torch.randperm(remainder.shape[0], device=codebook.device)][:num_dead] + codebook[dead_mask] = remainder.detach() + + return num_fixed diff --git a/scripts/rqvae/data.py b/scripts/rqvae/data.py new file mode 100644 index 00000000..972f8e59 --- /dev/null +++ b/scripts/rqvae/data.py @@ -0,0 +1,57 @@ +import numpy as np +import pickle + +from irec.data.base import BaseDataset +from irec.data.transforms import Transform + + +class EmbeddingDataset(BaseDataset): + def __init__(self, data_path): + self.data_path = data_path + with open(data_path, 'rb') as f: + self.data = pickle.load(f) + + self.item_ids = np.array(self.data['item_id'], dtype=np.int64) + self.embeddings = np.array(self.data['embedding'], dtype=np.float32) + + def __getitem__(self, idx): + index = self.item_ids[idx] + tensor_emb = 
self.embeddings[idx] + return { + 'item_id': index, + 'embedding': tensor_emb, + 'embedding_dim': len(tensor_emb) + } + + def __len__(self): + return len(self.embeddings) + + +# class PairEmbeddingDataset(BaseDataset): +# def __init__(self, data_path): +# self.data_path = data_path +# with open(data_path, 'rb') as f: +# self.data = pickle.load(f) + +# for num in ['fst', 'snd']: +# setattr(self, f'{num}_item_id', np.array(self.data[f'{num}_item_id'], dtype=np.int64)) +# setattr(self, f'{num}_embedding', np.array(self.data[f'{num}_embedding'], dtype=np.float32)) + +# def __getitem__(self, idx): +# result = {} + +# for key in ['fst_item_id', 'fst_embedding', 'snd_item_id', 'snd_embedding']: +# result[key] = self.__getattribute__(key)[idx] + +# return result + + +class ProcessEmbeddings(Transform): + def __init__(self, embedding_dim, keys): + self.embedding_dim = embedding_dim + self.keys = keys + + def __call__(self, batch): + for key in self.keys: + batch[key] = batch[key].reshape(-1, self.embedding_dim) + return batch \ No newline at end of file diff --git a/scripts/rqvae/infer_best.py b/scripts/rqvae/infer_best.py new file mode 100644 index 00000000..3b2ade15 --- /dev/null +++ b/scripts/rqvae/infer_best.py @@ -0,0 +1,132 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import InferenceRunner + +from irec.utils import fix_random_seed + +from data import EmbeddingDataset, ProcessEmbeddings +from models import BestRQVAE + +SEED_VALUE = 42 +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + +BATCH_SIZE = 1024 + +INPUT_DIM = 4096 +HIDDEN_DIM = 32 +CODEBOOK_SIZE = 256 +NUM_CODEBOOKS = 3 + +BETA = 0.25 +MODEL_PATH = 'rqvae_beauty_best_8_1_1_cf_best_0.0079.pth' + +EXPERIMENT_NAME = 'rqvae_beauty_best_8_1_1_cf' +IREC_PATH = '../../' + + +def main(): + fix_random_seed(SEED_VALUE) + + dataset = EmbeddingDataset( + data_path=os.path.join(IREC_PATH, 'data/Beauty/content_embeddings_train_only_0.8_0.1_0.1.pkl') + ) + + dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])) + + cf_embeddings = torch.load( + os.path.join(IREC_PATH, 'results', 'sasrec_beauty_8_1_1_32d_item_embeddings.pt'), + map_location='cpu' + )['weight'] + + model = BestRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + cf_loss_weight=1.0, + cf_embeddings=cf_embeddings + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + callbacks = [ + cb.LoadModel(os.path.join(IREC_PATH, 'checkpoints', MODEL_PATH)), + + cb.BatchMetrics(metrics=lambda model_outputs, _: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + }, name='valid'), + + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + }, + ), + + cb.Logger().every_num_steps(len(dataloader)), + + cb.InferenceSaver( + 
metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']},
+            save_path=os.path.join(IREC_PATH, 'results', f'{EXPERIMENT_NAME}_clusters.json'),
+            format='json'
+        )
+    ]
+
+    logger.debug('Everything is ready for training process!')
+
+    runner = InferenceRunner(
+        model=model,
+        dataset=dataloader,
+        callbacks=callbacks,
+    )
+    runner.run()
+
+    import json
+    from collections import defaultdict
+    import numpy as np
+
+    with open(os.path.join(IREC_PATH, 'results', f'{EXPERIMENT_NAME}_clusters.json'), 'r') as f:
+        mappings = json.load(f)
+
+    inter = {}
+    sem_2_ids = defaultdict(list)
+    for mapping in mappings:
+        item_id = mapping['item_id']
+        clusters = mapping['clusters']
+        inter[int(item_id)] = clusters
+        sem_2_ids[tuple(clusters)].append(int(item_id))
+
+    for semantics, items in sem_2_ids.items():
+        assert len(items) <= CODEBOOK_SIZE, str(len(items))
+        collision_solvers = np.random.permutation(CODEBOOK_SIZE)[:len(items)].tolist()
+        for item_id, collision_solver in zip(items, collision_solvers):
+            inter[item_id].append(collision_solver)
+            for i in range(len(inter[item_id])):
+                inter[item_id][i] += CODEBOOK_SIZE * i
+
+    with open(os.path.join(IREC_PATH, 'results', f'{EXPERIMENT_NAME}_clusters_colisionless.json'), 'w') as f:
+        json.dump(inter, f, indent=2)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/rqvae/infer_new.py b/scripts/rqvae/infer_new.py
new file mode 100644
index 00000000..75e3d54b
--- /dev/null
+++ b/scripts/rqvae/infer_new.py
@@ -0,0 +1,100 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import InferenceRunner
+
+from irec.utils import fix_random_seed
+
+from data import EmbeddingDataset, ProcessEmbeddings
+from models import NewRQVAE
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+BATCH_SIZE = 1024
+
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+
+BETA = 0.25
+MODEL_PATH = 'rqvae_beauty_new_best_0.0131.pth'
+EXPERIMENT_NAME = 'rqvae_beauty_new'
+IREC_PATH = '../../'
+
+
+def main():
+    fix_random_seed(SEED_VALUE)
+
+    dataset = EmbeddingDataset(
+        data_path=os.path.join(IREC_PATH, 'data/Beauty/content_embeddings.pkl')
+    )
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=False,
+        drop_last=False,
+    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))
+
+    model = NewRQVAE(
+        input_dim=INPUT_DIM,
+        num_codebooks=NUM_CODEBOOKS,
+        codebook_size=CODEBOOK_SIZE,
+        embedding_dim=HIDDEN_DIM,
+        layers=[2048, 1024, 512, 256, 128],
+        dropout_prob=0.1,
+        beta=BETA,
+        quant_loss_weight=1.0,
+    ).to(DEVICE)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.debug(f'Overall parameters: {total_params:,}')
+    logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+    callbacks = [
+        cb.LoadModel(os.path.join(IREC_PATH, 'checkpoints', MODEL_PATH)),
+
+        cb.BatchMetrics(metrics=lambda model_outputs, _: {
+            'loss': model_outputs['loss'],
+            'recon_loss': model_outputs['recon_loss'],
+            'rqvae_loss': model_outputs['rqvae_loss'],
+        }, name='valid'),
+
+        cb.MetricAccumulator(
+            accumulators={
+                'valid/loss': cb.MeanAccumulator(),
+                'valid/recon_loss': cb.MeanAccumulator(),
+                'valid/rqvae_loss': cb.MeanAccumulator(),
+            },
+        ),
+
+
cb.Logger().every_num_steps(len(dataloader)), + + cb.InferenceSaver( + metrics=lambda batch, model_outputs, _: {'item_id': batch['item_id'], 'clusters': model_outputs['clusters']}, + save_path=os.path.join(IREC_PATH, 'results', f'{EXPERIMENT_NAME}_clusters.json'), + format='json' + ) + ] + + logger.debug('Everything is ready for training process!') + + runner = InferenceRunner( + model=model, + dataset=dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/rqvae/models.py b/scripts/rqvae/models.py new file mode 100644 index 00000000..ca69d8b0 --- /dev/null +++ b/scripts/rqvae/models.py @@ -0,0 +1,548 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, input_dim, output_dim, dropout=0.1): + super().__init__() + self.input_dim = input_dim + self.output_dim = output_dim + + self.norm = nn.LayerNorm(input_dim) + self.layer = nn.Linear(input_dim, output_dim) + self.act = nn.GELU() + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + embedding = x + embedding = self.norm(embedding) + embedding = self.layer(embedding) + embedding = self.act(embedding) + embedding = self.dropout(embedding) + + if self.input_dim == self.output_dim: + return embedding + x + return embedding + + +class Tower(nn.Module): + def __init__(self, dims, dropout): + super().__init__() + self.layers = nn.ModuleList() + for i in range(len(dims) - 1): + self.layers.append(ResidualBlock(dims[i], dims[i + 1], dropout)) + + def forward(self, x): + embedding = x + for layer in self.layers: + embedding = layer(embedding) + return embedding + + +class OldRQVAE(nn.Module): + def __init__( + self, + input_dim, + num_codebooks, + codebook_size, + embedding_dim, + layers, + dropout_prob=0.0, + beta=0.25, + quant_loss_weight=1.0, + + ): + super().__init__() + self.register_buffer('beta', torch.tensor(beta)) + + self.input_dim = input_dim + self.num_codebooks = num_codebooks + self.codebook_size = codebook_size + self.embedding_dim = embedding_dim + self.quant_loss_weight = quant_loss_weight + + self.layers = layers + self.dropout_prob = dropout_prob + + self.encoder_layer_dims = [self.input_dim] + self.layers + [self.embedding_dim] + self.decoder_layer_dims = self.encoder_layer_dims[::-1] + + # TODO add inizialisation with AE + # self.encoder = Tower( + # dims=self.encoder_layer_dims, + # dropout=self.dropout_prob + # ) + # self.decoder = Tower( + # dims=self.decoder_layer_dims, + # dropout=self.dropout_prob + # ) + + self.encoder = self.make_encoding_tower(input_dim, embedding_dim) + self.decoder = self.make_encoding_tower(embedding_dim, input_dim) + + self.codebooks = torch.nn.ParameterList() + for _ in range(num_codebooks): + cb = torch.FloatTensor(codebook_size, embedding_dim) + self.codebooks.append(cb) + + @staticmethod + def make_encoding_tower(d1, d2, bias=False): + return torch.nn.Sequential( + nn.Linear(d1, d1), + nn.ReLU(), + nn.Linear(d1, d2), + nn.ReLU(), + nn.Linear(d2, d2, bias=bias) + ) + + @staticmethod + def get_codebook_indices(remainder, codebook): + dist = torch.cdist(remainder, codebook) + return dist.argmin(dim=-1) + + def forward(self, inputs): + latent_vector = self.encoder(inputs['embedding']) + + latent_restored = 0 + rqvae_loss = 0 + clusters = [] + remainder = latent_vector + for codebook in self.codebooks: + codebook_indices = self.get_codebook_indices(remainder, codebook) + clusters.append(codebook_indices) + + quantized = codebook[codebook_indices] + 
codebook_vectors = remainder + (quantized - remainder).detach() + + rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach()) + rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach()) + + latent_restored += codebook_vectors + remainder = remainder - codebook_vectors + + embeddings_restored = self.decoder(latent_restored) + recon_loss = F.mse_loss(embeddings_restored, inputs['embedding']) + loss = (recon_loss + self.quant_loss_weight * rqvae_loss).mean() + + clusters_counts = [] + for cluster in clusters: + clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size)) + + return loss, { + 'loss': loss.item(), + 'recon_loss': recon_loss.mean().item(), + 'rqvae_loss': rqvae_loss.mean().item(), + + 'clusters_counts': clusters_counts, + 'clusters': torch.stack(clusters).T, + 'embedding_hat': embeddings_restored, + } + + +class VectorQuantizer(nn.Module): + + def __init__( + self, + codebook_size, + embedding_dim, + mu=0.25, + ): + super().__init__() + self.codebook_size = codebook_size + self.embedding_dim = embedding_dim + self.mu = mu + + self.embedding = nn.Embedding(self.codebook_size, self.embedding_dim) + + def get_codebook(self): + return self.embedding.weight + + def forward(self, latent_embeddings): + # Get closest centroids + d = torch.sum(latent_embeddings**2, dim=1, keepdim=True) + torch.sum(self.embedding.weight**2, dim=1, keepdim=True).t() - 2 * torch.matmul(latent_embeddings, self.embedding.weight.t()) + indices = torch.argmin(d, dim=-1) + + x_q = self.embedding(indices) + + # compute loss for embedding + commitment_loss = F.mse_loss(x_q.detach(), latent_embeddings) + codebook_loss = F.mse_loss(x_q, latent_embeddings.detach()) + + quantization_loss = codebook_loss + self.mu * commitment_loss + + # preserve gradients + x_q = latent_embeddings + (x_q - latent_embeddings).detach() + + indices = indices.view(latent_embeddings.shape[:-1]) + + return x_q, quantization_loss, indices + + +class ResidualVectorQuantizer(nn.Module): + def __init__( + self, + num_codebooks, + codebook_size, + embedding_dim, + ): + super().__init__() + self.num_codebooks = num_codebooks + self.codebook_size = codebook_size + self.embedding_dim = embedding_dim + + self.vq_layers: list[VectorQuantizer] = nn.ModuleList([ + VectorQuantizer(codebook_size, embedding_dim) for _ in range(num_codebooks) + ]) + + def forward(self, latent_embeddings): + all_losses = [] + all_indices = [] + + x_q = 0 + residual = latent_embeddings + + for quantizer in self.vq_layers: + x_res, loss, indices = quantizer(residual) + residual = residual - x_res + x_q = x_q + x_res + + all_losses.append(loss) + all_indices.append(indices) + + mean_losses = torch.stack(all_losses).mean() + all_indices = torch.stack(all_indices, dim=-1) + + return x_q, mean_losses, all_indices + + +class NewRQVAE(nn.Module): + def __init__( + self, + input_dim, + num_codebooks, + codebook_size, + embedding_dim, + layers, + dropout_prob=0.0, + beta=0.25, + quant_loss_weight=1.0, + cf_loss_weight=1.0, + cf_embeddings=None + ): + super().__init__() + + self.input_dim = input_dim + self.num_codebooks = num_codebooks + self.codebook_size = codebook_size + self.embedding_dim = embedding_dim + self.beta = beta + self.quant_loss_weight = quant_loss_weight + + self.layers = layers + self.dropout_prob = dropout_prob + self.cf_embeddings = cf_embeddings + self.cf_loss_weight = cf_loss_weight + + self.encoder_layer_dims = [self.input_dim] + self.layers + [self.embedding_dim] + self.decoder_layer_dims = 
self.encoder_layer_dims[::-1] + + # TODO add inizialisation with AE + self.encoder = Tower( + dims=self.encoder_layer_dims, + dropout=self.dropout_prob + ) + self.decoder = Tower( + dims=self.decoder_layer_dims, + dropout=self.dropout_prob + ) + + self.rq = ResidualVectorQuantizer( + num_codebooks=num_codebooks, + codebook_size=codebook_size, + embedding_dim=embedding_dim + ) + + @staticmethod + def get_codebook_indices(remainder, quantizer): + dist = torch.sum(remainder**2, dim=1, keepdim=True) + torch.sum(quantizer.embedding.weight**2, dim=1, keepdim=True).t() - 2 * torch.matmul(remainder, quantizer.embedding.weight.t()) + return dist.argmin(dim=-1) + + def forward(self, inputs): + latent_vector = self.encoder(inputs['embedding']) + + latent_restored = 0 + rqvae_loss = 0 + clusters = [] + remainder = latent_vector + for quantizer in self.rq.vq_layers: + codebook_indices = self.get_codebook_indices(remainder, quantizer) + clusters.append(codebook_indices) + + quantized = quantizer.embedding(codebook_indices) + codebook_vectors = remainder + (quantized - remainder).detach() + + rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach()) + rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach()) + + # codebook_vectors, quantizer_loss, codebook_indices = quantizer(remainder) + # rqvae_loss += quantizer_loss + + latent_restored += codebook_vectors + remainder = remainder - codebook_vectors + + embeddings_restored = self.decoder(latent_restored) + recon_loss = F.mse_loss(embeddings_restored, inputs['embedding']) + + # TODO for now + # if self.cf_embeddings is not None: + # cf_embedding_in_batch = self.cf_embeddings[item_ids] + # cf_embedding_in_batch = torch.from_numpy(cf_embedding_in_batch).to(quantized_embeddings.device) + # cf_loss = self.CF_loss(quantized_embeddings, cf_embedding_in_batch) + # else: + cf_loss = torch.as_tensor(0.0) + + loss = (recon_loss + self.quant_loss_weight * rqvae_loss + self.cf_loss_weight * cf_loss).mean() + + clusters_counts = [] + for cluster in clusters: + clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size)) + + # loss, recon_loss, cf_loss, rq_loss = self.compute_loss( + # content_embeddings=content_embeddings, + # out_embeddings=out_embeddings, + # item_ids=item_ids, + # rq_loss=rq_loss, + # quantized_embeddings=quantized_embeddings + # ) + + return loss, { + 'loss': loss.item(), + 'recon_loss': recon_loss.mean().item(), + 'rqvae_loss': rqvae_loss.mean().item(), + 'cf_loss': cf_loss.item(), + + 'clusters_counts': clusters_counts, + 'clusters': torch.stack(clusters).T, + 'embedding_hat': embeddings_restored, + } + + # def CF_loss(self, quantized_rep, encoded_rep): + # batch_size = quantized_rep.size(0) + # labels = torch.arange(batch_size, dtype=torch.long, device=quantized_rep.device) + # similarities = quantized_rep @ encoded_rep.T + # cf_loss = F.cross_entropy(similarities, labels) + # return cf_loss + + # @torch.no_grad() + # def get_indices(self, content_embeddings): + # latent_embeddings = self.encoder(content_embeddings) + # _, _, indices = self.rq(latent_embeddings) + # return indices + + # def compute_loss(self, content_embeddings, out_embeddings, item_ids, rq_loss, quantized_embeddings): + # if self.loss_type == 'mse': + # recon_loss = F.mse_loss(content_embeddings, out_embeddings, reduction='mean') + # elif self.loss_type == 'l1': + # recon_loss = F.l1_loss(content_embeddings, out_embeddings, reduction='mean') + # else: + # raise ValueError('incompatible loss type') + + # if 
self.cf_embeddings is not None: + # cf_embedding_in_batch = self.cf_embeddings[item_ids] + # cf_embedding_in_batch = torch.from_numpy(cf_embedding_in_batch).to(quantized_embeddings.device) + # cf_loss = self.CF_loss(quantized_embeddings, cf_embedding_in_batch) + # else: + # cf_loss = torch.as_tensor(0.0) + + # total_loss = recon_loss + self.quant_loss_weight * rq_loss + self.cf_loss_weight * cf_loss + + # return total_loss, recon_loss, cf_loss, rq_loss + + +# class BestRQVAE(nn.Module): +# def __init__( +# self, +# input_dim, +# num_codebooks, +# codebook_size, +# embedding_dim, +# beta=0.25, +# quant_loss_weight=1.0, +# ): +# super().__init__() +# self.register_buffer('beta', torch.tensor(beta)) + +# self.input_dim = input_dim +# self.num_codebooks = num_codebooks +# self.codebook_size = codebook_size +# self.embedding_dim = embedding_dim +# self.quant_loss_weight = quant_loss_weight + +# # TODO add inizialisation with AE +# self.encoder = self.make_encoding_tower(input_dim, embedding_dim) +# self.decoder = self.make_encoding_tower(embedding_dim, input_dim) + +# self.codebooks = torch.nn.ParameterList() +# for _ in range(num_codebooks): +# cb = torch.FloatTensor(codebook_size, embedding_dim) +# self.codebooks.append(cb) + +# @staticmethod +# def make_encoding_tower(d1, d2, bias=False): +# return torch.nn.Sequential( +# nn.Linear(d1, d1), +# nn.ReLU(), +# nn.Linear(d1, d2), +# nn.ReLU(), +# nn.Linear(d2, d2, bias=bias) +# ) + +# @staticmethod +# def get_codebook_indices(remainder, codebook): +# dist = torch.cdist(remainder, codebook) +# return dist.argmin(dim=-1) + +# def forward(self, inputs): +# latent_vector = self.encoder(inputs['embedding']) + +# latent_restored = 0 +# rqvae_loss = 0 +# clusters = [] +# remainder = latent_vector +# for codebook in self.codebooks: +# codebook_indices = self.get_codebook_indices(remainder, codebook) +# clusters.append(codebook_indices) + +# quantized = codebook[codebook_indices] +# codebook_vectors = remainder + (quantized - remainder).detach() + +# rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach()) +# rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach()) + +# latent_restored += codebook_vectors +# remainder = remainder - codebook_vectors + +# embeddings_restored = self.decoder(latent_restored) +# recon_loss = torch.nn.functional.mse_loss(embeddings_restored, inputs['embedding']) +# loss = (recon_loss + rqvae_loss).mean() + +# clusters_counts = [] +# for cluster in clusters: +# clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size)) + +# return loss, { +# 'loss': loss.item(), +# 'recon_loss': recon_loss.mean().item(), +# 'rqvae_loss': rqvae_loss.mean().item(), + +# 'clusters_counts': clusters_counts, +# 'clusters': torch.stack(clusters).T, +# 'embedding_hat': embeddings_restored, +# } + + +class BestRQVAE(nn.Module): + def __init__( + self, + input_dim, + num_codebooks, + codebook_size, + embedding_dim, + beta=0.25, + quant_loss_weight=1.0, + cf_loss_weight=1.0, + cf_embeddings=None + ): + super().__init__() + self.register_buffer('beta', torch.tensor(beta)) + + self.input_dim = input_dim + self.num_codebooks = num_codebooks + self.codebook_size = codebook_size + self.embedding_dim = embedding_dim + self.quant_loss_weight = quant_loss_weight + + self.cf_loss_weight = cf_loss_weight + if cf_embeddings is not None: + self.register_buffer('cf_embeddings', torch.tensor(cf_embeddings)) + else: + self.cf_embeddings = None + + # TODO add inizialisation with AE + self.encoder = 
self.make_encoding_tower(input_dim, embedding_dim)
+        self.decoder = self.make_encoding_tower(embedding_dim, input_dim)
+
+        self.codebooks = torch.nn.ParameterList()
+        for _ in range(num_codebooks):
+            # wrapped in nn.Parameter (rather than a bare FloatTensor) so the
+            # codebooks are registered and the codebook term of rqvae_loss can
+            # actually update them; values are filled in by InitCodebooks
+            cb = nn.Parameter(torch.empty(codebook_size, embedding_dim))
+            self.codebooks.append(cb)
+
+    @staticmethod
+    def make_encoding_tower(d1, d2, bias=False):
+        return torch.nn.Sequential(
+            nn.Linear(d1, d1),
+            nn.ReLU(),
+            nn.Linear(d1, d2),
+            nn.ReLU(),
+            nn.Linear(d2, d2, bias=bias)
+        )
+
+    @staticmethod
+    def get_codebook_indices(remainder, codebook):
+        dist = torch.cdist(remainder, codebook)
+        return dist.argmin(dim=-1)
+
+    def forward(self, inputs):
+        latent_vector = self.encoder(inputs['embedding'])
+        item_ids = inputs['item_id']
+
+        latent_restored = 0
+        rqvae_loss = 0
+        clusters = []
+        remainder = latent_vector
+        for codebook in self.codebooks:
+            codebook_indices = self.get_codebook_indices(remainder, codebook)
+            clusters.append(codebook_indices)
+
+            quantized = codebook[codebook_indices]
+            codebook_vectors = remainder + (quantized - remainder).detach()
+
+            rqvae_loss += self.beta * torch.nn.functional.mse_loss(remainder, quantized.detach())
+            rqvae_loss += torch.nn.functional.mse_loss(quantized, remainder.detach())
+
+            latent_restored += codebook_vectors
+            remainder = remainder - codebook_vectors
+
+        embeddings_restored = self.decoder(latent_restored)
+        recon_loss = torch.nn.functional.mse_loss(embeddings_restored, inputs['embedding'])
+
+        if self.cf_embeddings is not None:
+            cf_embedding_in_batch = self.cf_embeddings[item_ids]
+            cf_loss = self.CF_loss(latent_restored, cf_embedding_in_batch)
+        else:
+            cf_loss = torch.as_tensor(0.0)
+
+        loss = (recon_loss + self.quant_loss_weight * rqvae_loss + self.cf_loss_weight * cf_loss).mean()
+
+        clusters_counts = []
+        for cluster in clusters:
+            clusters_counts.append(torch.bincount(cluster, minlength=self.codebook_size))
+
+        return loss, {
+            'loss': loss.item(),
+            'recon_loss': recon_loss.mean().item(),
+            'rqvae_loss': rqvae_loss.mean().item(),
+            'cf_loss': cf_loss.item(),
+
+            'clusters_counts': clusters_counts,
+            'clusters': torch.stack(clusters).T,
+            'embedding_hat': embeddings_restored,
+        }
+
+    def CF_loss(self, quantized_rep, encoded_rep):
+        batch_size = quantized_rep.size(0)
+        labels = torch.arange(batch_size, dtype=torch.long, device=quantized_rep.device)
+        similarities = quantized_rep @ encoded_rep.T
+        cf_loss = F.cross_entropy(similarities, labels)
+        return cf_loss
diff --git a/scripts/rqvae/train_best.py b/scripts/rqvae/train_best.py
new file mode 100644
index 00000000..d6aad7df
--- /dev/null
+++ b/scripts/rqvae/train_best.py
@@ -0,0 +1,149 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDataset, ProcessEmbeddings
+from models import BestRQVAE
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+NUM_EPOCHS = 500
+BATCH_SIZE = 1024
+
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+
+EXPERIMENT_NAME = 'rqvae_beauty_0'
+IREC_PATH = '../../'
+
+
+def main():
+    fix_random_seed(SEED_VALUE)
+
+    dataset = EmbeddingDataset(
+        data_path=os.path.join(IREC_PATH, 'data/Beauty/content_embeddings_0_part.pkl')
+    )
+
+    train_dataloader = DataLoader(
+
dataset, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).repeat(NUM_EPOCHS) + + valid_dataloader = DataLoader( + dataset, + batch_size=BATCH_SIZE, + shuffle=False, + drop_last=False, + ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])) + + LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS) + + # cf_embeddings = torch.load( + # os.path.join(IREC_PATH, 'results', 'sasrec_beauty_8_1_1_32d_item_embeddings.pt'), + # map_location='cpu' + # )['weight'] + + model = BestRQVAE( + input_dim=INPUT_DIM, + num_codebooks=NUM_CODEBOOKS, + codebook_size=CODEBOOK_SIZE, + embedding_dim=HIDDEN_DIM, + beta=BETA, + quant_loss_weight=1.0, + # cf_loss_weight=1.0, + # cf_embeddings=cf_embeddings + ).to(DEVICE) + + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + logger.debug(f'Overall parameters: {total_params:,}') + logger.debug(f'Trainable parameters: {trainable_params:,}') + + optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True) + + callbacks = [ + InitCodebooks(valid_dataloader), + + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'cf_loss': model_outputs['cf_loss'] + }, name='train'), + + FixDeadCentroids(valid_dataloader), + + cb.MetricAccumulator( + accumulators={ + 'train/loss': cb.MeanAccumulator(), + 'train/recon_loss': cb.MeanAccumulator(), + 'train/rqvae_loss': cb.MeanAccumulator(), + 'train/cf_loss': cb.MeanAccumulator(), + 'num_dead/0': cb.MeanAccumulator(), + 'num_dead/1': cb.MeanAccumulator(), + 'num_dead/2': cb.MeanAccumulator(), + }, + reset_every_num_steps=LOG_EVERY_NUM_STEPS + ), + + cb.Validation( + dataset=valid_dataloader, + callbacks=[ + cb.BatchMetrics(metrics=lambda model_outputs, batch: { + 'loss': model_outputs['loss'], + 'recon_loss': model_outputs['recon_loss'], + 'rqvae_loss': model_outputs['rqvae_loss'], + 'cf_loss': model_outputs['cf_loss'] + }, name='valid'), + cb.MetricAccumulator( + accumulators={ + 'valid/loss': cb.MeanAccumulator(), + 'valid/recon_loss': cb.MeanAccumulator(), + 'valid/rqvae_loss': cb.MeanAccumulator(), + 'valid/cf_loss': cb.MeanAccumulator(), + } + ), + ], + ).every_num_steps(LOG_EVERY_NUM_STEPS), + + cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS), + cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')), + + cb.EarlyStopping( + metric='valid/recon_loss', + patience=40, + minimize=True, + model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME) + ).every_num_steps(LOG_EVERY_NUM_STEPS), + ] + + logger.debug('Everything is ready for training process!') + + runner = TrainingRunner( + model=model, + optimizer=optimizer, + dataset=train_dataloader, + callbacks=callbacks, + ) + runner.run() + + +if __name__ == '__main__': + main() diff --git a/scripts/rqvae/train_best_pairs.py b/scripts/rqvae/train_best_pairs.py new file mode 100644 index 00000000..d4be6464 --- /dev/null +++ b/scripts/rqvae/train_best_pairs.py @@ -0,0 +1,156 @@ +from loguru import logger +import os + +import torch + +import irec.callbacks as cb +from irec.data.dataloader import DataLoader +from irec.data.transforms import Collate, ToTorch, ToDevice +from irec.runners import 
TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDataset, ProcessEmbeddings
+from models import OldRQVAE, NewRQVAE, BestRQVAE
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+NUM_EPOCHS = 200
+BATCH_SIZE = 1024
+
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+
+EXPERIMENT_NAME = 'rqvae_beauty_best_cf_signal_1.0'
+IREC_PATH = '../../'
+
+
+def main():
+    fix_random_seed(SEED_VALUE)
+
+    dataset = EmbeddingDataset(
+        data_path=os.path.join(IREC_PATH, 'data/Beauty/content_embeddings.pkl')
+    )
+
+    train_dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        drop_last=True,
+    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding'])).repeat(NUM_EPOCHS)
+
+    valid_dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=False,
+        drop_last=False,
+    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))
+
+    # index_dataloader = DataLoader(
+    #     dataset,
+    #     batch_size=len(dataset),
+    #     shuffle=False,
+    #     drop_last=False,
+    # ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(ProcessEmbeddings(embedding_dim=INPUT_DIM, keys=['embedding']))
+
+    cf_embeddings = torch.load(
+        os.path.join(IREC_PATH, 'results', 'sasrec_item_embeddings_32d.pt'),
+        map_location='cpu'
+    )['weight']
+
+    LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
+
+    model = BestRQVAE(
+        input_dim=INPUT_DIM,
+        num_codebooks=NUM_CODEBOOKS,
+        codebook_size=CODEBOOK_SIZE,
+        embedding_dim=HIDDEN_DIM,
+        beta=BETA,
+        quant_loss_weight=1.0,
+        cf_loss_weight=1.0,
+        cf_embeddings=cf_embeddings
+    ).to(DEVICE)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.debug(f'Overall parameters: {total_params:,}')
+    logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True)
+
+    callbacks = [
+        InitCodebooks(valid_dataloader),
+
+        cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+            'loss': model_outputs['loss'],
+            'recon_loss': model_outputs['recon_loss'],
+            'rqvae_loss': model_outputs['rqvae_loss'],
+            'cf_loss': model_outputs['cf_loss']
+        }, name='train'),
+
+        FixDeadCentroids(valid_dataloader),
+
+        cb.MetricAccumulator(
+            accumulators={
+                'train/loss': cb.MeanAccumulator(),
+                'train/recon_loss': cb.MeanAccumulator(),
+                'train/rqvae_loss': cb.MeanAccumulator(),
+                'train/cf_loss': cb.MeanAccumulator(),
+                'num_dead/0': cb.MeanAccumulator(),
+                'num_dead/1': cb.MeanAccumulator(),
+                'num_dead/2': cb.MeanAccumulator(),
+            },
+            reset_every_num_steps=LOG_EVERY_NUM_STEPS
+        ),
+
+        cb.Validation(
+            dataset=valid_dataloader,
+            callbacks=[
+                cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+                    'loss': model_outputs['loss'],
+                    'recon_loss': model_outputs['recon_loss'],
+                    'rqvae_loss': model_outputs['rqvae_loss'],
+                    'cf_loss': model_outputs['cf_loss']
+                }, name='valid'),
+                cb.MetricAccumulator(
+                    accumulators={
+                        'valid/loss': cb.MeanAccumulator(),
+                        'valid/recon_loss': cb.MeanAccumulator(),
+                        'valid/rqvae_loss': cb.MeanAccumulator(),
+                        'valid/cf_loss': cb.MeanAccumulator(),
+                    }
+                ),
+            ],
+        ).every_num_steps(LOG_EVERY_NUM_STEPS),
+
+        cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
+        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+
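+        # Cadence note (a sketch of the arithmetic, not a library guarantee):
+        # train_dataloader.repeat(NUM_EPOCHS) makes len(train_dataloader) cover all
+        # epochs, so LOG_EVERY_NUM_STEPS = len(train_dataloader) // NUM_EPOCHS is one
+        # pass over the data; every_num_steps(LOG_EVERY_NUM_STEPS) therefore fires
+        # validation, logging and the early-stopping check once per epoch-equivalent.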
+        cb.EarlyStopping(
+            metric='valid/recon_loss',
+            patience=40,
+            minimize=True,
+            model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
+        ).every_num_steps(LOG_EVERY_NUM_STEPS),
+    ]
+
+    logger.debug('Everything is ready for training process!')
+
+    runner = TrainingRunner(
+        model=model,
+        optimizer=optimizer,
+        dataset=train_dataloader,
+        callbacks=callbacks,
+    )
+    runner.run()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/rqvae/train_new.py b/scripts/rqvae/train_new.py
new file mode 100644
index 00000000..13fea9d3
--- /dev/null
+++ b/scripts/rqvae/train_new.py
@@ -0,0 +1,218 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner, TrainingRunnerContext
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDataset, process_embeddings
+from models import OldRQVAE, NewRQVAE, BestRQVAE
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+NUM_EPOCHS = 200
+BATCH_SIZE = 1024
+
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+
+EXPERIMENT_NAME = 'rqvae_beauty_new'
+IREC_PATH = '../../'
+
+
+# NOTE: the local classes below shadow the InitCodebooks and FixDeadCentroids
+# imported from callbacks above (presumably experimental variants specific to
+# this script).
+class InitCodebooks(cb.TrainingCallback):
+    def __init__(self, dataloader):
+        super().__init__()
+        self._dataloader = dataloader
+
+    # Data-dependent initialization: codebook i is seeded with the encoder
+    # residuals that remain after subtracting the nearest vectors of
+    # codebooks 0..i-1 for a random batch.
+    @torch.no_grad()
+    def before_run(self, runner: TrainingRunner):
+        for i in range(len(runner.model.rq.vq_layers)):
+            X = next(iter(self._dataloader))['embedding']
+            idx = torch.randperm(X.shape[0], device=X.device)[:runner.model.codebook_size]
+            remainder = runner.model.encoder(X[idx])
+
+            for j in range(i):
+                codebook_indices = runner.model.get_codebook_indices(remainder, runner.model.rq.vq_layers[j])
+                codebook_vectors = runner.model.rq.vq_layers[j].embedding(codebook_indices)
+                remainder = remainder - codebook_vectors
+
+            runner.model.rq.vq_layers[i].embedding.weight.data = remainder.detach()
+
+
+class FixDeadCentroids(cb.TrainingCallback):
+    def __init__(self, dataloader):
+        super().__init__()
+        self._dataloader = dataloader
+
+    def after_step(self, runner: TrainingRunner, context: TrainingRunnerContext):
+        for i, num_fixed in enumerate(self.fix_dead_codebooks(runner)):
+            context.metrics[f'num_dead/{i}'] = num_fixed
+
+    @torch.no_grad()
+    def fix_dead_codebooks(self, runner: TrainingRunner):
+        num_fixed = []
+        for codebook_idx, quantizer in enumerate(runner.model.rq.vq_layers):
+            centroid_counts = torch.zeros(quantizer.codebook_size, dtype=torch.long, device=DEVICE)
+            random_batch = next(iter(self._dataloader))['embedding']
+
+            for batch in self._dataloader:
+                remainder = runner.model.encoder(batch['embedding'])
+                for l in range(codebook_idx):
+                    _, _, ind = runner.model.rq.vq_layers[l](remainder)
+                    remainder = remainder - runner.model.rq.vq_layers[l].embedding(ind)
+
+                indices = runner.model.get_codebook_indices(remainder, quantizer)
+                centroid_counts.scatter_add_(0, indices, torch.ones_like(indices))
+
+            dead_mask = (centroid_counts == 0)
+            num_dead = int(dead_mask.sum().item())
+            num_fixed.append(num_dead)
+            if num_dead == 0:
+                continue
+
+            remainder = runner.model.encoder(random_batch)
+            for l in range(codebook_idx):
+                ind = runner.model.get_codebook_indices(remainder, runner.model.rq.vq_layers[l])
+                remainder = remainder - runner.model.rq.vq_layers[l].embedding(ind)
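+            # Re-seed every dead centroid with a residual from a random batch
+            # so all codebook entries stay reachable (a dead-code restart; the
+            # usage counts above are gathered over the full validation loader).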
+            remainder = remainder[torch.randperm(remainder.shape[0], device=remainder.device)][:num_dead]
+            runner.model.rq.vq_layers[codebook_idx].embedding.weight.data[dead_mask] = remainder.detach()
+
+        return num_fixed
+
+
+def main():
+    fix_random_seed(SEED_VALUE)
+
+    dataset = EmbeddingDataset(
+        data_path=os.path.join(IREC_PATH, 'data/Beauty/content_embeddings.pkl')
+    )
+
+    train_dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        drop_last=True,
+    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(process_embeddings).repeat(NUM_EPOCHS)
+
+    valid_dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=False,
+        drop_last=False,
+    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(process_embeddings)
+
+    # index_dataloader = DataLoader(
+    #     dataset,
+    #     batch_size=len(dataset),
+    #     shuffle=False,
+    #     drop_last=False,
+    # ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(process_embeddings)
+
+    # cf_embedding_path = '../data/Beauty/collaborative_item_embeddings.pt'
+    # if cf_embedding_path is not None:
+    #     cf_embeddings = torch.load(cf_embedding_path).squeeze().detach().numpy()
+
+    LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
+
+    model = NewRQVAE(
+        input_dim=INPUT_DIM,
+        num_codebooks=NUM_CODEBOOKS,
+        codebook_size=CODEBOOK_SIZE,
+        embedding_dim=HIDDEN_DIM,
+        layers=[2048, 1024, 512, 256, 128],
+        dropout_prob=0.1,
+        beta=BETA,
+        quant_loss_weight=1.0,
+        # cf_loss_weight=0.0,
+        # cf_embeddings=cf_embeddings
+    ).to(DEVICE)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.debug(f'Overall parameters: {total_params:,}')
+    logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True)
+
+    callbacks = [
+        InitCodebooks(valid_dataloader),
+
+        cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+            'loss': model_outputs['loss'],
+            'recon_loss': model_outputs['recon_loss'],
+            'rqvae_loss': model_outputs['rqvae_loss'],
+        }, name='train'),
+
+        FixDeadCentroids(valid_dataloader),
+
+        cb.MetricAccumulator(
+            accumulators={
+                'train/loss': cb.MeanAccumulator(),
+                'train/recon_loss': cb.MeanAccumulator(),
+                'train/rqvae_loss': cb.MeanAccumulator(),
+                # 'train/cf_loss': MeanAccumulator(),
+                'num_dead/0': cb.MeanAccumulator(),
+                'num_dead/1': cb.MeanAccumulator(),
+                'num_dead/2': cb.MeanAccumulator(),
+            },
+            reset_every_num_steps=LOG_EVERY_NUM_STEPS
+        ),
+
+        cb.Validation(
+            dataset=valid_dataloader,
+            callbacks=[
+                cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+                    'loss': model_outputs['loss'],
+                    'recon_loss': model_outputs['recon_loss'],
+                    'rqvae_loss': model_outputs['rqvae_loss'],
+                }, name='valid'),
+                cb.MetricAccumulator(
+                    accumulators={
+                        'valid/loss': cb.MeanAccumulator(),
+                        'valid/recon_loss': cb.MeanAccumulator(),
+                        'valid/rqvae_loss': cb.MeanAccumulator(),
+                        # 'valid/cf_loss': MeanAccumulator(),
+                    }
+                ),
+            ],
+        ).every_num_steps(LOG_EVERY_NUM_STEPS),
+
+        cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
+        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs')),
+
+        cb.EarlyStopping(
+            metric='valid/recon_loss',
+            patience=40,
+            minimize=True,
+            model_path=os.path.join(IREC_PATH, 'checkpoints', EXPERIMENT_NAME)
+        ).every_num_steps(LOG_EVERY_NUM_STEPS),
+    ]
+
+    logger.debug('Everything is ready for training process!')
+
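+    # The runner appears to be step-based rather than epoch-based: it makes a
+    # single pass over train_dataloader, which repeat(NUM_EPOCHS) has already
+    # expanded to NUM_EPOCHS passes over the data (judging by the repeat() and
+    # every_num_steps() usage in this script).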
+    runner = TrainingRunner(
+        model=model,
+        optimizer=optimizer,
+        dataset=train_dataloader,
+        callbacks=callbacks,
+    )
+    runner.run()
+
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/rqvae/train_old.py b/scripts/rqvae/train_old.py
new file mode 100644
index 00000000..65950781
--- /dev/null
+++ b/scripts/rqvae/train_old.py
@@ -0,0 +1,149 @@
+from loguru import logger
+import os
+
+import torch
+
+import irec.callbacks as cb
+from irec.data.dataloader import DataLoader
+from irec.data.transforms import Collate, ToTorch, ToDevice
+from irec.runners import TrainingRunner
+
+from irec.utils import fix_random_seed
+
+from callbacks import InitCodebooks, FixDeadCentroids
+from data import EmbeddingDataset, process_embeddings
+from models import OldRQVAE, NewRQVAE, BestRQVAE
+
+SEED_VALUE = 42
+DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+NUM_EPOCHS = 200
+BATCH_SIZE = 1024
+
+INPUT_DIM = 4096
+HIDDEN_DIM = 32
+CODEBOOK_SIZE = 256
+NUM_CODEBOOKS = 3
+BETA = 0.25
+LR = 1e-4
+
+EXPERIMENT_NAME = 'rqvae_beauty_old'
+IREC_PATH = '../../'
+
+
+def main():
+    fix_random_seed(SEED_VALUE)
+
+    dataset = EmbeddingDataset(
+        data_path=os.path.join(IREC_PATH, 'data/Beauty/content_embeddings.pkl')
+    )
+
+    train_dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=True,
+        drop_last=True,
+    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(process_embeddings).repeat(NUM_EPOCHS)
+
+    valid_dataloader = DataLoader(
+        dataset,
+        batch_size=BATCH_SIZE,
+        shuffle=False,
+        drop_last=False,
+    ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(process_embeddings)
+
+    # index_dataloader = DataLoader(
+    #     dataset,
+    #     batch_size=len(dataset),
+    #     shuffle=False,
+    #     drop_last=False,
+    # ).map(Collate()).map(ToTorch()).map(ToDevice(DEVICE)).map(process_embeddings)
+
+    # cf_embedding_path = '../data/Beauty/collaborative_item_embeddings.pt'
+    # if cf_embedding_path is not None:
+    #     cf_embeddings = torch.load(cf_embedding_path).squeeze().detach().numpy()
+
+    LOG_EVERY_NUM_STEPS = int(len(train_dataloader) // NUM_EPOCHS)
+
+    model = OldRQVAE(
+        input_dim=INPUT_DIM,
+        num_codebooks=NUM_CODEBOOKS,
+        codebook_size=CODEBOOK_SIZE,
+        embedding_dim=HIDDEN_DIM,
+        layers=[2048, 1024, 512, 256, 128],
+        dropout_prob=0.1,
+        beta=BETA,
+        quant_loss_weight=1.0,
+        # cf_loss_weight=0.0,
+        # cf_embeddings=cf_embeddings
+    ).to(DEVICE)
+
+    total_params = sum(p.numel() for p in model.parameters())
+    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+    logger.debug(f'Overall parameters: {total_params:,}')
+    logger.debug(f'Trainable parameters: {trainable_params:,}')
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=LR, fused=True)
+
+    callbacks = [
+        InitCodebooks(valid_dataloader),
+
+        cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+            'loss': model_outputs['loss'],
+            'recon_loss': model_outputs['recon_loss'],
+            'rqvae_loss': model_outputs['rqvae_loss'],
+        }, name='train'),
+
+        FixDeadCentroids(valid_dataloader),
+
+        cb.MetricAccumulator(
+            accumulators={
+                'train/loss': cb.MeanAccumulator(),
+                'train/recon_loss': cb.MeanAccumulator(),
+                'train/rqvae_loss': cb.MeanAccumulator(),
+                # 'train/cf_loss': MeanAccumulator(),
+                'num_dead/0': cb.MeanAccumulator(),
+                'num_dead/1': cb.MeanAccumulator(),
+                'num_dead/2': cb.MeanAccumulator(),
+            },
+            reset_every_num_steps=LOG_EVERY_NUM_STEPS
+        ),
+
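+        # Full pass over the validation loader every LOG_EVERY_NUM_STEPS steps
+        # (i.e. roughly once per training epoch), averaging per-batch losses.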
+        cb.Validation(
+            dataset=valid_dataloader,
+            callbacks=[
+                cb.BatchMetrics(metrics=lambda model_outputs, batch: {
+                    'loss': model_outputs['loss'],
+                    'recon_loss': model_outputs['recon_loss'],
+                    'rqvae_loss': model_outputs['rqvae_loss'],
+                }, name='valid'),
+                cb.MetricAccumulator(
+                    accumulators={
+                        'valid/loss': cb.MeanAccumulator(),
+                        'valid/recon_loss': cb.MeanAccumulator(),
+                        'valid/rqvae_loss': cb.MeanAccumulator(),
+                        # 'valid/cf_loss': MeanAccumulator(),
+                    }
+                ),
+            ],
+        ).every_num_steps(LOG_EVERY_NUM_STEPS),
+
+        cb.Logger().every_num_steps(LOG_EVERY_NUM_STEPS),
+        cb.TensorboardLogger(experiment_name=EXPERIMENT_NAME, logdir=os.path.join(IREC_PATH, 'tensorboard_logs'))
+    ]
+
+    logger.debug('Everything is ready for training process!')
+
+    runner = TrainingRunner(
+        model=model,
+        optimizer=optimizer,
+        dataset=train_dataloader,
+        callbacks=callbacks,
+    )
+    runner.run()
+
+
+if __name__ == '__main__':
+    main()

From f108f2ed77898be8b725459ef3980766c9136fde Mon Sep 17 00:00:00 2001
From: Iskander Bagautdinov <112892889+iskbaga@users.noreply.github.com>
Date: Mon, 17 Nov 2025 20:36:27 +0300
Subject: [PATCH 5/5] tiger extract validation loss from tensor

---
 scripts/tiger/train.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/scripts/tiger/train.py b/scripts/tiger/train.py
index 3cd6fd99..f436dd40 100644
--- a/scripts/tiger/train.py
+++ b/scripts/tiger/train.py
@@ -129,7 +129,15 @@
         cb.Validation(
             dataset=valid_dataloder,
             callbacks=[
-                cb.BatchMetrics(metrics=lambda model_outputs, _: model_outputs, name='validation'),
+                cb.BatchMetrics(metrics=lambda model_outputs, _: {
+                    'loss': model_outputs['loss'].item(),
+                    'recall@5': model_outputs['recall@5'].tolist(),
+                    'recall@10': model_outputs['recall@10'].tolist(),
+                    'recall@20': model_outputs['recall@20'].tolist(),
+                    'ndcg@5': model_outputs['ndcg@5'].tolist(),
+                    'ndcg@10': model_outputs['ndcg@10'].tolist(),
+                    'ndcg@20': model_outputs['ndcg@20'].tolist(),
+                }, name='validation'),
                 cb.MetricAccumulator(
                     accumulators={
                         'validation/loss': cb.MeanAccumulator(),