# train.py — forked from exe1023/FusionNet (122 lines, 6.09 KB)
import argparse
import data
from module import FusionNet, decode
from metrics import batch_score
from embedding import load_embedding
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
import torch
from torch import optim
# Run on GPU when one is available; also used to pin DataLoader memory below.
use_cuda = torch.cuda.is_available()

# Command-line configuration for the FusionNet reading-comprehension trainer.
parser = argparse.ArgumentParser(description='Model and Training parameters')
# Model Architecture
# NOTE(review): --rnn_type, --hidden_size and --embedding_size are parsed but
# never read in this file — the model is built with hidden_size=125 and
# word_dim=300 hard-coded below. Confirm whether they are consumed elsewhere.
parser.add_argument('--rnn_type', type=str, default='lstm', help='the rnn cell used')
parser.add_argument('--hidden_size', type=int, default=256, help='the hidden size of RNNs [256]')
parser.add_argument('--embedding_size', type=int, default=200, help='the embedding size [200]')
parser.add_argument('--vocab_size', type=int, default=25000, help='the vocab size [25000]')
# Training hyperparameter
# --word_base toggles word- vs. character-level data loading (passed to data.load_data).
parser.add_argument('--word_base', action='store_true')
parser.add_argument('--dropout', type=float, default=0.6)
parser.add_argument('--rnn_layer', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=32, help='the size of batch [32]')
# NOTE(review): --lr is parsed but the Adamax optimizer below is constructed
# with its library default learning rate, not args.lr — confirm intent.
parser.add_argument('--lr', type=float, default=1e-3, help='the learning rate of encoder [1e-3]')
parser.add_argument('--valid_ratio', type=float, default=0.05)
parser.add_argument('--valid_iters', type=int, default=1, help='run validation batch every N epochs [1]')
parser.add_argument('--max_sent', type=int, default=25, help='max length of encoder, decoder')
parser.add_argument('--display_freq', type=int, default=10, help='display training status every N iters [10]')
parser.add_argument('--save_freq', type=int, default=1, help='save model every N epochs [1]')
parser.add_argument('--epoch', type=int, default=25, help='train for N epochs [25]')
# When set, word embeddings are initialised from --embedding_source (300-d,
# see load_embedding call below); otherwise the embedding layer starts random.
parser.add_argument('--init_embedding', action='store_true', help='whether init embedding')
parser.add_argument('--embedding_source', type=str, default='./', help='pretrained embedding path')
# Parsed at import time: this module is written as a script, not a library.
args = parser.parse_args()
if __name__ == '__main__':
    # ---------------- Data loading & vocabulary ----------------
    train = data.load_data('train.json', args.word_base)
    test = data.load_data('test.json', args.word_base)
    # train = data.load_data('train_squad.json', args.word_base)
    # test = data.load_data('dev_squad.json', args.word_base)
    vocabulary, pad_lens = data.build_vocab(train, test, args.vocab_size)
    print('Vocab size: %d | Max context: %d | Max question: %d'%(
          len(vocabulary), pad_lens[0], pad_lens[1]))

    # Hold out a fraction of the training set for validation.
    train, valid = data.split_exp(train, args.valid_ratio)
    print('Train: %d | Valid: %d | Test: %d'%(len(train), len(valid), len(test)))

    train_engine = DataLoader(data.DataEngine(train, vocabulary, pad_lens),
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=use_cuda)
    valid_engine = DataLoader(data.DataEngine(valid, vocabulary, pad_lens),
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=use_cuda)
    test_engine = data.DataEngine(test, vocabulary, pad_lens)

    # ---------------- Model ----------------
    # Optionally initialise the embedding table from pretrained 300-d vectors.
    if args.init_embedding:
        w2v = load_embedding(args.embedding_source,
                             vocabulary.to_idx,
                             300)
    else:
        w2v = None
    # NOTE(review): hidden_size=125 and word_dim=300 are hard-coded here;
    # args.hidden_size / args.embedding_size are ignored — confirm intent.
    fusion_net = FusionNet(vocab_size=len(vocabulary),
                           word_dim=300,
                           hidden_size=125,
                           rnn_layer=args.rnn_layer,
                           dropout=args.dropout,
                           pretrained_embedding=w2v)
    if use_cuda:
        fusion_net = fusion_net.cuda()

    # Span prediction is trained as two classifications (start index, end index).
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): Adamax is built with its default lr; args.lr is unused.
    optimizer = optim.Adamax(fusion_net.parameters())

    # ---------------- Training loop ----------------
    for epoch in range(args.epoch):
        batch = 0
        fusion_net.train()
        for context, q, ans_offset, appear in train_engine:
            context = Variable(context).cuda() if use_cuda else Variable(context)
            q = Variable(q).cuda() if use_cuda else Variable(q)
            start_ans = Variable(ans_offset[:, 0]).cuda() if use_cuda else Variable(ans_offset[:, 0])
            end_ans = Variable(ans_offset[:, 1]).cuda() if use_cuda else Variable(ans_offset[:, 1])
            appear = Variable(appear).cuda() if use_cuda else Variable(appear)

            start, end, start_attn, end_attn = fusion_net(context, q, appear)
            loss = criterion(start_attn, start_ans) + criterion(end_attn, end_ans)

            # BUG FIX: gradients must be cleared every step; without this they
            # accumulate across all previous batches and epochs, corrupting
            # every update after the first.
            optimizer.zero_grad()
            loss.backward()
            # Clip the global grad norm to 10 to stabilise RNN training.
            nn.utils.clip_grad_norm(fusion_net.parameters(), 10)
            optimizer.step()

            # Greedy (top-1) span decode for running train metrics.
            start, end, scores = decode(start.data.cpu(), end.data.cpu(), 1)
            f1_score, exact_match_score = batch_score(start, end, ans_offset)
            if batch % args.display_freq == 0:
                print('epoch: %d | batch: %d/%d| loss: %f | f1: %f | exact: %f'%(
                      epoch, batch, len(train_engine), loss.data[0],
                      f1_score, exact_match_score
                      ))
            batch += 1

        # ---------------- Validation ----------------
        valid_f1, valid_exact = 0, 0
        fusion_net.eval()
        for context, q, ans_offset, appear in valid_engine:
            context = Variable(context).cuda() if use_cuda else Variable(context)
            q = Variable(q).cuda() if use_cuda else Variable(q)
            start_ans = Variable(ans_offset[:, 0]).cuda() if use_cuda else Variable(ans_offset[:, 0])
            end_ans = Variable(ans_offset[:, 1]).cuda() if use_cuda else Variable(ans_offset[:, 1])
            appear = Variable(appear).cuda() if use_cuda else Variable(appear)

            start, end, start_attn, end_attn = fusion_net(context, q, appear)
            start, end, scores = decode(start.data.cpu(), end.data.cpu(), 1)
            f1_score, exact_match_score = batch_score(start, end, ans_offset)
            valid_f1 += f1_score
            valid_exact += exact_match_score
        # Scores are per-batch averages summed over batches, so divide by
        # the number of batches, not the number of examples.
        print('epoch: %d | valid_f1: %f | valid_exact: %f'%(
              epoch, valid_f1/len(valid_engine), valid_exact/len(valid_engine)
              ))

        # Periodic checkpoint (whole-module pickle, as in the original).
        if epoch % args.save_freq == 0:
            torch.save(fusion_net, 'model.cpt')

    torch.save(fusion_net, 'model.final')