Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .idea/DeepRec.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions .idea/deployment.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

327 changes: 109 additions & 218 deletions .idea/workspace.xml

Large diffs are not rendered by default.

87 changes: 56 additions & 31 deletions models/item_ranking/bprmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
the twenty-fifth conference on uncertainty in artificial intelligence. AUAI Press, 2009..
"""

import tensorflow as tf
import time
import numpy as np

import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, Lambda

from utils.evaluation.RankingMetrics import *

Expand All @@ -20,10 +22,18 @@
__status__ = "Development"


class EmbeddingLookup(tf.keras.layers.Layer):
def __init__(self, input_embedding, **kwargs):
super(EmbeddingLookup, self).__init__(**kwargs)
self.input_embedding = input_embedding

def call(self, inputs):
return tf.nn.embedding_lookup(self.input_embedding, inputs)


class BPRMF(object):
def __init__(self, sess, num_user, num_item, learning_rate=0.001, reg_rate=0.1, epoch=500, batch_size=1024,
def __init__(self, num_user, num_item, learning_rate=0.001, reg_rate=0.1, epoch=500, batch_size=1024,
verbose=False, t=5, display_step=1000):
self.sess = sess
self.num_user = num_user
self.num_item = num_item
self.learning_rate = learning_rate
Expand All @@ -37,13 +47,14 @@ def __init__(self, sess, num_user, num_item, learning_rate=0.001, reg_rate=0.1,
self.user_id = None
self.item_id = None
self.neg_item_id = None
self.y = None
self.P = None
self.Q = None
self.pred_y = None
self.pred_y_neg = None
self.loss = None
self.optimizer = None
self.loss_estimator = tf.keras.metrics.Mean(name='train_loss')
self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
self.model = None

self.test_data = None
self.user = None
Expand All @@ -56,26 +67,25 @@ def __init__(self, sess, num_user, num_item, learning_rate=0.001, reg_rate=0.1,
print("You are running BPRMF.")

def build_network(self, num_factor=30):
self.user_id = tf.placeholder(dtype=tf.int32, shape=[None], name='user_id')
self.item_id = tf.placeholder(dtype=tf.int32, shape=[None], name='item_id')
self.neg_item_id = tf.placeholder(dtype=tf.int32, shape=[None], name='neg_item_id')
self.y = tf.placeholder("float", [None], 'rating')

self.P = tf.Variable(tf.random_normal([self.num_user, num_factor], stddev=0.01))
self.Q = tf.Variable(tf.random_normal([self.num_item, num_factor], stddev=0.01))
self.user_id = Input(shape=(1,), dtype=tf.int32, name='user_id')
self.item_id = Input(shape=(1,), dtype=tf.int32, name='item_id')
self.neg_item_id = Input(shape=(1,), dtype=tf.int32, name='neg_item_id')

user_latent_factor = tf.nn.embedding_lookup(self.P, self.user_id)
item_latent_factor = tf.nn.embedding_lookup(self.Q, self.item_id)
neg_item_latent_factor = tf.nn.embedding_lookup(self.Q, self.neg_item_id)
self.P = tf.Variable(tf.random.normal([self.num_user, num_factor], stddev=0.01))
self.Q = tf.Variable(tf.random.normal([self.num_item, num_factor], stddev=0.01))

self.pred_y = tf.reduce_sum(tf.multiply(user_latent_factor, item_latent_factor), 1)
self.pred_y_neg = tf.reduce_sum(tf.multiply(user_latent_factor, neg_item_latent_factor), 1)
user_id = Lambda(lambda x: tf.squeeze(x))(self.user_id)
item_id = Lambda(lambda x: tf.squeeze(x))(self.item_id)
neg_item_id = Lambda(lambda x: tf.squeeze(x))(self.neg_item_id)

self.loss = - tf.reduce_sum(
tf.log(tf.sigmoid(self.pred_y - self.pred_y_neg))) +\
self.reg_rate * (tf.nn.l2_loss(self.P) + tf.nn.l2_loss(self.Q))
user_latent_factor = EmbeddingLookup(self.P)(user_id)
item_latent_factor = EmbeddingLookup(self.Q)(item_id)
neg_item_latent_factor = EmbeddingLookup(self.Q)(neg_item_id)

self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss)
pred_y = tf.reduce_sum(tf.multiply(user_latent_factor, item_latent_factor), 1)
pred_y_neg = tf.reduce_sum(tf.multiply(user_latent_factor, neg_item_latent_factor), 1)
self.model = Model(inputs=[self.user_id, self.item_id, self.neg_item_id],
outputs=[pred_y, pred_y_neg])

def prepare_data(self, train_data, test_data):
"""
Expand All @@ -94,6 +104,18 @@ def prepare_data(self, train_data, test_data):
self.test_users = set([u for u in self.test_data.keys() if len(self.test_data[u]) > 0])
print("data preparation finished.")

@tf.function
def train_op(self, batch_user, batch_item, batch_item_neg):
with tf.GradientTape() as tape:
pred_y, pred_y_neg = self.model([batch_user, batch_item, batch_item_neg])
#TODO: use tf loss object
loss = - tf.reduce_sum(
tf.math.log(tf.sigmoid(pred_y - pred_y_neg))) + \
self.reg_rate * (tf.nn.l2_loss(self.P) + tf.nn.l2_loss(self.Q))
gradients_of_model = tape.gradient(loss, self.model.trainable_variables)
self.optimizer.apply_gradients(zip(gradients_of_model, self.model.trainable_variables))
self.loss_estimator(loss)

def train(self):
idxs = np.random.permutation(self.num_training) # shuffled ordering
user_random = list(self.user[idxs])
Expand All @@ -107,17 +129,19 @@ def train(self):
# train
for i in range(self.total_batch):
start_time = time.time()
batch_user = user_random[i * self.batch_size:(i + 1) * self.batch_size]
batch_user = np.array(user_random[i * self.batch_size:(i + 1) * self.batch_size])
batch_item = item_random[i * self.batch_size:(i + 1) * self.batch_size]
batch_item_neg = item_random_neg[i * self.batch_size:(i + 1) * self.batch_size]

_, loss = self.sess.run((self.optimizer, self.loss), feed_dict={self.user_id: batch_user,
self.item_id: batch_item,
self.neg_item_id: batch_item_neg})
np_batch_user = np.expand_dims(np.array(batch_user), -1)
np_batch_item = np.expand_dims(np.array(batch_item), -1)
np_batch_item_neg = np.expand_dims(np.array(batch_item_neg), -1)

self.train_op(np_batch_user, np_batch_item, np_batch_item_neg)

if i % self.display_step == 0:
if self.verbose:
print("Index: %04d; cost= %.9f" % (i + 1, np.mean(loss)))
print("Index: %04d; cost= %.9f" % (i + 1, self.loss_estimator.result()))
print("one iteration: %s seconds." % (time.time() - start_time))

def test(self):
Expand All @@ -126,9 +150,6 @@ def test(self):
def execute(self, train_data, test_data):
self.prepare_data(train_data, test_data)

init = tf.global_variables_initializer()
self.sess.run(init)

for epoch in range(self.epochs):
self.train()
if epoch % self.T == 0:
Expand All @@ -140,7 +161,11 @@ def save(self, path):
saver.save(self.sess, path)

def predict(self, user_id, item_id):
return self.sess.run([self.pred_y], feed_dict={self.user_id: user_id, self.item_id: item_id})[0]
user_id = tf.expand_dims(tf.convert_to_tensor(user_id), -1)
item_id = tf.expand_dims(tf.convert_to_tensor(item_id), -1)
dummy_neg_id = tf.zeros(item_id.shape, tf.int32)
pred_y, pred_y_neg = self.model([user_id, item_id, dummy_neg_id])
return pred_y.numpy()

def _get_neg_items(self, data):
all_items = set(np.arange(self.num_item))
Expand Down
Loading