# forked from pmixer/SASRec.pytorch
import os
import time
import argparse

import numpy as np
import torch

from model import SASRec
from utils import *

def str2bool(s):
    """Return True for 'true', False for 'false', and raise otherwise."""
    if s not in {'false', 'true'}:
        raise ValueError('Not a valid boolean string')
    return s == 'true'

# Create the argument parser and define the command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=True)    # required: must be given on the command line
parser.add_argument('--train_dir', required=True)  # required: suffix of the output directory
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--maxlen', default=50, type=int)
parser.add_argument('--hidden_units', default=50, type=int)
parser.add_argument('--num_blocks', default=2, type=int)
parser.add_argument('--num_epochs', default=201, type=int)
parser.add_argument('--num_heads', default=1, type=int)
parser.add_argument('--dropout_rate', default=0.5, type=float)
parser.add_argument('--l2_emb', default=0.0, type=float)
parser.add_argument('--device', default='cpu', type=str)
parser.add_argument('--inference_only', default=False, type=str2bool)
parser.add_argument('--state_dict_path', default=None, type=str)

args = parser.parse_args()  # parse the command-line arguments
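
# Example invocation (a sketch; dataset name and hyperparameters follow the
# upstream SASRec.pytorch README and may need adjusting for your data):
#   python main.py --dataset=ml-1m --train_dir=default --maxlen=200 --dropout_rate=0.2 --device=cuda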

# Create the output directory ({dataset}_{train_dir}) if it does not exist;
# the args dump, training log, and checkpoints are written there.
if not os.path.isdir(args.dataset + '_' + args.train_dir):
    os.makedirs(args.dataset + '_' + args.train_dir)
# Record the run's arguments in args.txt, one 'name,value' pair per line,
# sorted by name. No explicit f.close() is needed: the with statement closes
# the file automatically, even if an exception is raised inside the block.
with open(os.path.join(args.dataset + '_' + args.train_dir, 'args.txt'), 'w') as f:
    f.write('\n'.join([str(k) + ',' + str(v) for k, v in sorted(vars(args).items(), key=lambda x: x[0])]))

if __name__ == '__main__':
    # Split the interaction data (data_partition lives in utils.py).
    dataset = data_partition(args.dataset)
    [user_train, user_valid, user_test, usernum, itemnum] = dataset
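    # user_train/user_valid/user_test come from the leave-one-out split used
    # by SASRec: per user, the last item is held out for testing, the
    # second-to-last for validation, and the rest form the training sequence.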
    # Number of mini-batches per epoch. The sampler draws users at random with
    # replacement, so this only fixes how many steps one epoch runs.
    num_batch = len(user_train) // args.batch_size  # tail? + ((len(user_train) % args.batch_size) != 0)
    # cc: total number of items across all training sequences.
    cc = 0.0
    for u in user_train:
        cc += len(user_train[u])
    print('average sequence length: %.2f' % (cc / len(user_train)))

    f = open(os.path.join(args.dataset + '_' + args.train_dir, 'log.txt'), 'w')  # evaluation log

    # Sampler that generates training batches by sampling users (see utils.py).
    sampler = WarpSampler(user_train, usernum, itemnum, batch_size=args.batch_size, maxlen=args.maxlen, n_workers=3)
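    # WarpSampler keeps n_workers background processes feeding a queue; each
    # next_batch() call pops one batch of (user ids, input sequences, positive
    # targets, sampled negative targets), all left-padded/truncated to maxlen
    # with item id 0 used as padding.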
    model = SASRec(usernum, itemnum, args).to(args.device)  # no ReLU activation in original SASRec implementation?

    # Xavier-init the weight matrices; xavier_normal_ raises on 1-D parameters
    # (e.g. biases), and those are simply skipped.
    for name, param in model.named_parameters():
        try:
            torch.nn.init.xavier_normal_(param.data)
        except Exception:
            pass  # just ignore those failed init layers
    # model.apply(torch.nn.init.xavier_uniform_) fails on embedding init:
    # 'Embedding' object has no attribute 'dim'

    model.train()  # enable training mode (dropout active)

    epoch_start_idx = 1
    if args.state_dict_path is not None:
        try:
            model.load_state_dict(torch.load(args.state_dict_path, map_location=torch.device(args.device)))
            # Recover the epoch to resume from; the checkpoint filename embeds 'epoch=<n>.'.
            tail = args.state_dict_path[args.state_dict_path.find('epoch=') + 6:]
            epoch_start_idx = int(tail[:tail.find('.')]) + 1
        except Exception:  # loading may fail across pytorch versions; debug with pdb below
            print('failed loading state_dict, please check the file path:', args.state_dict_path)
            print('pdb enabled for a quick check; type exit() when done')
            import pdb; pdb.set_trace()

    if args.inference_only:
        model.eval()
        t_test = evaluate(model, dataset, args)
        print('test (NDCG@10: %.4f, HR@10: %.4f)' % (t_test[0], t_test[1]))

    # ce_criterion = torch.nn.CrossEntropyLoss()
    # https://github.com/NVIDIA/pix2pixHD/issues/9 how could an old bug appear again...
    bce_criterion = torch.nn.BCEWithLogitsLoss()  # torch.nn.BCELoss()
    adam_optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98))
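    # BCEWithLogitsLoss folds the sigmoid into the loss, which is more
    # numerically stable than torch.sigmoid followed by BCELoss; the model
    # accordingly returns raw logits for the positive and negative items.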

    T = 0.0
    t0 = time.time()

    for epoch in range(epoch_start_idx, args.num_epochs + 1):
        if args.inference_only: break  # just to decrease indentation
        for step in range(num_batch):  # tqdm(range(num_batch), total=num_batch, ncols=70, leave=False, unit='b'):
            u, seq, pos, neg = sampler.next_batch()  # tuples to ndarray
            u, seq, pos, neg = np.array(u), np.array(seq), np.array(pos), np.array(neg)
            pos_logits, neg_logits = model(u, seq, pos, neg)
            pos_labels, neg_labels = torch.ones(pos_logits.shape, device=args.device), torch.zeros(neg_logits.shape, device=args.device)
            # print("\neye ball check raw_logits:"); print(pos_logits); print(neg_logits) # check pos_logits > 0, neg_logits < 0
            adam_optimizer.zero_grad()
            indices = np.where(pos != 0)
            loss = bce_criterion(pos_logits[indices], pos_labels[indices])
            loss += bce_criterion(neg_logits[indices], neg_labels[indices])
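            # pos == 0 marks padding positions in the left-padded sequences, so
            # the masking above restricts the loss to real interactions:
            # positives are pushed towards label 1, sampled negatives towards 0.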
            # Optional L2 regularization on the item embeddings only.
            for param in model.item_emb.parameters():
                loss += args.l2_emb * torch.norm(param)
            loss.backward()
            adam_optimizer.step()
            print("loss in epoch {} iteration {}: {}".format(epoch, step, loss.item()))  # expected around 0.4~0.6 after the first few epochs

        if epoch % 20 == 0:
            model.eval()
            t1 = time.time() - t0
            T += t1
            print('Evaluating', end='')
            t_test = evaluate(model, dataset, args)
            t_valid = evaluate_valid(model, dataset, args)
            print('epoch:%d, time: %f(s), valid (NDCG@10: %.4f, HR@10: %.4f), test (NDCG@10: %.4f, HR@10: %.4f)'
                  % (epoch, T, t_valid[0], t_valid[1], t_test[0], t_test[1]))
            f.write(str(t_valid) + ' ' + str(t_test) + '\n')
            f.flush()
            t0 = time.time()
            model.train()

        if epoch == args.num_epochs:
            folder = args.dataset + '_' + args.train_dir
            fname = 'SASRec.epoch={}.lr={}.layer={}.head={}.hidden={}.maxlen={}.pth'
            fname = fname.format(args.num_epochs, args.lr, args.num_blocks, args.num_heads, args.hidden_units, args.maxlen)
            torch.save(model.state_dict(), os.path.join(folder, fname))

    f.close()
    sampler.close()
    print("Done")