From cc0845469548de174f9248499b2756125030f114 Mon Sep 17 00:00:00 2001 From: Yaroslav Zhurba Date: Sun, 22 Dec 2024 01:14:31 +0700 Subject: [PATCH] bert4rec original, timeline split, ce loss all rank --- configs/train/bert4rec_train_config.json | 10 +++++----- modeling/models/bert4rec.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/configs/train/bert4rec_train_config.json b/configs/train/bert4rec_train_config.json index 787c27c3..91418aa1 100644 --- a/configs/train/bert4rec_train_config.json +++ b/configs/train/bert4rec_train_config.json @@ -1,9 +1,9 @@ { - "experiment_name": "bert4rec_beauty", + "experiment_name": "bert4rec_beauty_dataset_bert_ce_loss", "best_metric": "eval/ndcg@20", "dataset": { - "type": "scientific", - "path_to_data_dir": "../data", + "type": "sequence", + "path_to_data_dir": "../data/sasrec_in_batch", "name": "Beauty", "max_sequence_length": 50, "samplers": { @@ -43,7 +43,7 @@ "num_heads": 2, "num_layers": 2, "dim_feedforward": 256, - "dropout": 0.2, + "dropout": 0.3, "activation": "gelu", "layer_norm_eps": 1e-9, "initializer_range": 0.02 @@ -52,7 +52,7 @@ "type": "basic", "optimizer": { "type": "adam", - "lr": 1e-4 + "lr": 0.001 }, "clip_grad_threshold": 5.0 }, diff --git a/modeling/models/bert4rec.py b/modeling/models/bert4rec.py index 40f1d331..baf4f774 100644 --- a/modeling/models/bert4rec.py +++ b/modeling/models/bert4rec.py @@ -72,11 +72,11 @@ def forward(self, inputs): ) # (batch_size, seq_len, embedding_dim), (batch_size, seq_len) embeddings = self._output_projection(embeddings) # (batch_size, seq_len, embedding_dim) - embeddings = torch.nn.functional.gelu(embeddings) # (batch_size, seq_len, embedding_dim) + # embeddings = torch.nn.functional.gelu(embeddings) # (batch_size, seq_len, embedding_dim) embeddings = torch.einsum( 'bsd,nd->bsn', embeddings, self._item_embeddings.weight ) # (batch_size, seq_len, num_items) - embeddings += self._bias[None, None, :] # (batch_size, seq_len, num_items) + # embeddings += self._bias[None, None, :] # (batch_size, seq_len, num_items) if self.training: # training mode all_sample_labels = inputs['{}.ids'.format(self._labels_prefix)] # (all_batch_events)