-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmodel.py
More file actions
executable file
·402 lines (309 loc) · 13.1 KB
/
model.py
File metadata and controls
executable file
·402 lines (309 loc) · 13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import transformer
class Encoder(nn.Module):
    """LSTM encoder over a padded, time-major batch of token ids.

    ``num_layers`` is accepted but not forwarded to ``nn.LSTM``, so the
    encoder always has a single layer. The embedding treats id 3 as padding
    (``padding_idx=3``).
    """

    def __init__(self, vocab_size, emb_size, hid_size, num_layers):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size, padding_idx=3)
        self.encoder = nn.LSTM(emb_size, hid_size)

    def forward(self, seqs, lens):
        """Encode ``seqs`` (indexed ``[time, batch]``) whose true lengths are ``lens``.

        Returns:
            ``(outputs, (h_n, c_n))`` in the caller's original batch order.
            ``outputs`` is ``(max(lens), batch, hid_size)``; positions past a
            sequence's length are zero-filled by ``pad_packed_sequence``.
        """
        # pack_padded_sequence (default enforce_sorted=True) needs longest-first
        # order; reverse=True keeps equal lengths in their original order.
        order = sorted(range(len(lens)), key=lens.__getitem__, reverse=True)
        sorted_lens = [lens[i] for i in order]
        embedded = self.embedding(seqs[:, order])
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, sorted_lens)
        packed_out, (h_n, c_n) = self.encoder(packed)
        padded_out, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_out)
        # restore[i] is the sorted position of original batch entry i.
        restore = [0] * len(order)
        for position, original in enumerate(order):
            restore[original] = position
        return padded_out[:, restore], (h_n[:, restore], c_n[:, restore])
class Decoder(nn.Module):
    """Single-step LSTM decoder with optional additive attention.

    ``num_layers`` is accepted but not forwarded to ``nn.LSTM`` (single layer).
    With ``use_attn`` the LSTM input is ``[context; embedding]``, where the
    context is an attention-weighted sum of the encoder outputs scored as
    ``v^T tanh(W_a [h; e_s])`` from the incoming hidden state ``h``.
    """

    def __init__(self, emb_size, hid_size, vocab_size, num_layers, use_attn=True):
        super(Decoder, self).__init__()
        # Module creation order is kept so parameter initialisation is unchanged.
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.out = nn.Linear(hid_size, vocab_size)
        self.use_attn = use_attn
        self.hid_size = hid_size
        if use_attn:
            self.decoder = nn.LSTM(emb_size + hid_size, hid_size)
            self.W_a = nn.Linear(hid_size * 2, hid_size)
            self.v = nn.Linear(hid_size, 1)
        else:
            self.decoder = nn.LSTM(emb_size, hid_size)

    def _attend(self, h_prev, encoder_outputs):
        """Return attention weights of shape ``(batch, 1, src_len)``.

        ``h_prev`` must be ``(1, batch, hid_size)`` and ``encoder_outputs``
        ``(src_len, batch, hid_size)`` so they can be concatenated per step.
        """
        h = h_prev.repeat(encoder_outputs.size(0), 1, 1)
        energy = torch.tanh(self.W_a(torch.cat((h, encoder_outputs), dim=2)))
        # Source positions whose encoder vector sums to exactly 0 are treated as
        # padding (presumably the zero-fill of pad_packed_sequence) and receive a
        # -1e5 penalty, i.e. ~0 weight after the softmax.
        pad_mask = (encoder_outputs.sum(dim=2) == 0).float()
        scores = self.v(energy).squeeze(-1) - 1e5 * pad_mask
        return F.softmax(scores, dim=0).permute(1, 0).unsqueeze(1)

    def forward(self, hidden, last_word, encoder_outputs, ret_out=False, ret_logits=False, ret_attn=False):
        """Run one decoding step.

        Args:
            hidden: LSTM state ``(h, c)``; ``h`` is ``(1, batch, hid_size)``.
            last_word: ids of the previous token, ``(1, batch)``.
            encoder_outputs: ``(src_len, batch, hid_size)``. Only read when
                attention is used or ``ret_attn`` is set.
            ret_out: also return the raw LSTM output.
            ret_logits: return unnormalised scores instead of log-probabilities.
            ret_attn: also return attention weights ``(batch, 1, src_len)``.
                Without attention these are all zeros.

        Returns:
            ``(scores, hidden)``, then ``output`` if ``ret_out``, then the
            attention weights if ``ret_attn``. ``scores`` is
            ``(1, batch, vocab)`` log-probs, or logits when ``ret_logits``.
        """
        embedded = self.embedding(last_word)
        attn_weights = None
        if self.use_attn:
            attn_weights = self._attend(hidden[0], encoder_outputs)
            # (batch, 1, src) x (batch, src, hid) -> (batch, 1, hid) -> (1, batch, hid)
            context_vec = attn_weights.bmm(encoder_outputs.permute(1, 0, 2)).permute(1, 0, 2)
            rnn_input = torch.cat((context_vec, embedded), dim=2)
        else:
            rnn_input = embedded
        output, hidden = self.decoder(rnn_input, hidden)

        scores = self.out(output)
        if not ret_logits:
            scores = F.log_softmax(scores, dim=2)
        result = (scores, hidden)
        if ret_out:
            result += (output,)
        if ret_attn:
            if attn_weights is None:
                # No attention module: report zero weights so callers that mix in
                # an attention-based copy distribution add no copy mass.
                attn_weights = encoder_outputs.new_zeros(
                    encoder_outputs.size(1), 1, encoder_outputs.size(0))
            result += (attn_weights,)
        return result
class Seq2Seq(nn.Module):
    """Encoder-decoder model bundling vocabulary handling, loss, optimiser and decoding.

    Args:
        encoder: module called as ``encoder(input_seq, input_lens)`` and
            returning ``(encoder_outputs, hidden)``.
        decoder: module called one step at a time as
            ``decoder(hidden, last_word, encoder_outputs, ...)``.
        i2w: index-to-word list; '_pad', '_unk' and '_go' are looked up by name.
        use_knowledge: if True, ``prep_batch`` appends the third field of each
            row (plus an '_eos' separator) to the encoder input.
        args: namespace providing ``lr``, ``l2_norm`` (Adam weight decay) and
            ``clip`` (gradient-norm clipping threshold).
        test: if True the NLL loss is summed over tokens instead of averaged.

    New tensors are created on the device of the model's parameters.
    """

    def __init__(self, encoder, decoder, i2w, use_knowledge, args, test=False):
        super(Seq2Seq, self).__init__()
        self.args = args
        self.use_knowledge = use_knowledge
        # Model
        self.encoder = encoder
        self.decoder = decoder
        # Vocab
        self.i2w = i2w
        self.w2i = {w: i for i, w in enumerate(i2w)}
        # Training: padding positions never contribute to the loss.
        if test:
            self.criterion = nn.NLLLoss(ignore_index=self.w2i['_pad'], reduction='sum')
        else:
            self.criterion = nn.NLLLoss(ignore_index=self.w2i['_pad'])
        self.optim = optim.Adam(lr=args.lr, params=self.parameters(), weight_decay=args.l2_norm)

    def _device(self):
        """Return the device holding the model's parameters."""
        return next(self.parameters()).device

    @staticmethod
    def _pad_ids(arr, pad):
        """Right-pad every id list in ``arr`` with ``pad``; return ``(padded, lengths)``."""
        lengths = [len(e) for e in arr]
        max_len = max(lengths)
        return [e + [pad] * (max_len - len(e)) for e in arr], lengths

    @staticmethod
    def _to_sentence(words):
        """Join ``words`` with spaces, stopping before the first '_eos'."""
        if '_eos' in words:
            words = words[:words.index('_eos')]
        return ' '.join(words)

    @staticmethod
    def _nucleus_filter(probs, top_p):
        """Zero every entry outside the top-p nucleus along the last dimension.

        The token whose cumulative mass first exceeds ``top_p`` is kept (as in
        the HuggingFace implementation), so the most likely token always is.
        """
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        remove_sorted = torch.cumsum(sorted_probs, dim=-1) > top_p
        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
        remove_sorted[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
        return probs.masked_fill(remove, 0.0)

    def prep_batch(self, rows, wh=128, wk=64):
        """Turn raw text rows into padded, time-major LongTensors.

        Each row is a sequence of whitespace-tokenised strings. ``row[0]`` is
        truncated to its last ``wh`` tokens and forms the encoder input, and
        ``row[1]`` is the target. With ``use_knowledge``, the first ``wk``
        tokens of ``row[2]`` plus '_eos' are appended to the input.
        Out-of-vocabulary words map to '_unk'.

        Returns:
            ``(input_seq, input_lens, target_seq, target_lens)``. The sequences
            are ``(max_len, batch)`` and the lengths are Python lists.
        """
        unk = self.w2i['_unk']
        pad = self.w2i['_pad']
        device = self._device()
        rows = [[e.split() for e in row] for row in rows]
        if self.use_knowledge:
            inputs = [row[0][-wh:] + row[2][:wk] + ['_eos'] for row in rows]
        else:
            inputs = [row[0][-wh:] for row in rows]
        input_ids = [[self.w2i.get(w, unk) for w in inp] for inp in inputs]
        input_seq, input_lens = self._pad_ids(input_ids, pad)
        input_seq = torch.tensor(input_seq, dtype=torch.long, device=device).t()
        target_ids = [[self.w2i.get(w, unk) for w in row[1]] for row in rows]
        target_seq, target_lens = self._pad_ids(target_ids, pad)
        target_seq = torch.tensor(target_seq, dtype=torch.long, device=device).t()
        return input_seq, input_lens, target_seq, target_lens

    def forward(self, input_seq, input_lens, target_seq, target_lens):
        """Teacher-forced pass returning per-step log-probs ``(tgt_len, batch, vocab)``.

        ``target_seq[0]`` is only used as the first decoder input and is never
        predicted, so ``probas[0]`` stays all zeros.
        """
        encoder_outputs, decoder_hidden = self.encoder(input_seq, input_lens)
        steps, batch = target_seq.size(0), target_seq.size(1)
        probas = torch.zeros(steps, batch, len(self.w2i), device=target_seq.device)
        last_word = target_seq[0].unsqueeze(0)
        for t in range(1, steps):
            decoder_output, decoder_hidden = self.decoder(decoder_hidden, last_word, encoder_outputs)
            probas[t] = decoder_output[0]
            last_word = target_seq[t].unsqueeze(0)
        return probas

    def _loss(self, proba, target_seq):
        """NLL over steps 1..T-1.

        Step 0 is skipped: its zero "log-probs" would be scored as a perfect
        prediction, which would deflate the mean loss.
        """
        return self.criterion(proba[1:].reshape(-1, proba.size(-1)), target_seq[1:].reshape(-1))

    def train(self, input_seq=True, input_lens=None, target_seq=None, target_lens=None):
        """Run one optimisation step on a ``prep_batch`` batch and return the loss.

        This name shadows ``nn.Module.train(mode)``, which ``nn.Module.eval()``
        calls as ``self.train(False)``. A lone bool (or no argument) is
        therefore forwarded to ``nn.Module.train`` so ``.train()`` and
        ``.eval()`` keep working.
        """
        if isinstance(input_seq, bool):
            return super(Seq2Seq, self).train(input_seq)
        self.optim.zero_grad()
        proba = self.forward(input_seq, input_lens, target_seq, target_lens)
        loss = self._loss(proba, target_seq)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.args.clip)
        self.optim.step()
        return loss.item()

    def eval_ppl(self, input_seq, input_lens, target_seq, target_lens):
        """Return the loss on a batch without updating parameters."""
        proba = self.forward(input_seq, input_lens, target_seq, target_lens)
        return self._loss(proba, target_seq).item()

    def decode(self, input_seq, input_lens, top_p=0, max_len=100, p_copy=0):
        """Generate one response per batch column.

        Args:
            input_seq, input_lens: encoder inputs as returned by ``prep_batch``.
            top_p: nucleus mass for sampling; 0 selects argmax decoding.
            max_len: number of decoding steps.
            p_copy: weight of the attention-based copy distribution added to the
                generation probabilities (the sum is not renormalised). When 0
                the copy term is skipped, so decoders without attention work too.

        Returns:
            List of strings, each cut at the first '_eos'.
        """
        batch_size = input_seq.size(1)
        use_copy = p_copy != 0
        predictions = torch.zeros((batch_size, max_len), dtype=torch.long)
        with torch.no_grad():
            encoder_outputs, decoder_hidden = self.encoder(input_seq, input_lens)
            last_word = torch.full((1, batch_size), self.w2i['_go'], dtype=torch.long, device=input_seq.device)
            if use_copy:
                # (src_len, batch, vocab) one-hot encoding of the source tokens.
                input_oh = F.one_hot(input_seq, len(self.w2i)).float()
            for t in range(max_len):
                copy_prob = None
                if use_copy:
                    decoder_output, decoder_hidden, attn = self.decoder(
                        decoder_hidden, last_word, encoder_outputs, ret_logits=top_p > 0, ret_attn=True)
                    # (1, batch, vocab): attention mass summed per source word id.
                    copy_prob = attn.bmm(input_oh.permute(1, 0, 2)).permute(1, 0, 2)
                else:
                    decoder_output, decoder_hidden = self.decoder(
                        decoder_hidden, last_word, encoder_outputs, ret_logits=top_p > 0)
                if top_p == 0:
                    # decoder_output holds log-probabilities here.
                    probs = torch.exp(decoder_output)
                    if copy_prob is not None:
                        probs = probs + p_copy * copy_prob
                    topi = probs.topk(1)[1]
                else:
                    probs = F.softmax(decoder_output, dim=-1)
                    if copy_prob is not None:
                        probs = probs + p_copy * copy_prob
                    probs = self._nucleus_filter(probs, top_p)
                    topi = torch.multinomial(probs.squeeze(0), 1)
                topi = topi.view(-1)
                predictions[:, t] = topi
                last_word = topi.view(1, -1)
        return [self._to_sentence([self.i2w[idx] for idx in row]) for row in predictions.tolist()]

    def save(self, name):
        """Pickle the whole model to ``name``."""
        torch.save(self, name)

    def load(self, name):
        """Copy parameters from a checkpoint written by ``save``.

        The file holds a whole pickled model, so it is loaded with
        ``weights_only=False``. PyTorch >= 2.6 defaults to ``True`` and would
        refuse it. SECURITY: unpickling can run arbitrary code, so only load
        trusted files.
        """
        import inspect

        load_kwargs = {'map_location': self._device()}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        self.load_state_dict(torch.load(name, **load_kwargs).state_dict())
class Transformer(nn.Module):
    """Wrapper giving ``transformer.Transformer`` a Seq2Seq-like batch/train/decode API.

    Unlike Seq2Seq, tensors here are batch-major: ``(batch, seq_len)``.

    Args:
        i2w: index-to-word list; '_pad' and '_unk' are looked up by name.
        use_knowledge: if True, ``prep_batch`` appends the third field of each
            row (plus an '_eos' separator) to the encoder input.
        args: namespace providing ``lr`` and ``clip`` (gradient-norm threshold).
        test: if True the cross-entropy loss is summed instead of averaged.

    NOTE(review): ``decode``/``eval_ppl`` do not switch modes themselves; if
    ``transformer.Transformer`` uses dropout, call ``.eval()`` first (now
    supported, see ``train``).
    """

    def __init__(self, i2w, use_knowledge, args, test=False):
        super(Transformer, self).__init__()
        self.args = args
        self.use_knowledge = use_knowledge
        # Vocab
        self.i2w = i2w
        self.w2i = {w: i for i, w in enumerate(i2w)}
        self.transformer = transformer.Transformer(len(i2w), len(i2w), src_pad_idx=self.w2i['_pad'], trg_pad_idx=self.w2i['_pad'])
        # Training: padding positions never contribute to the loss.
        if test:
            self.criterion = nn.CrossEntropyLoss(ignore_index=self.w2i['_pad'], reduction='sum')
        else:
            self.criterion = nn.CrossEntropyLoss(ignore_index=self.w2i['_pad'])
        self.optim = optim.Adam(lr=args.lr, params=self.parameters(), betas=(0.9, 0.997), eps=1e-09)

    def _device(self):
        """Return the device holding the model's parameters."""
        return next(self.parameters()).device

    @staticmethod
    def _pad_ids(arr, pad):
        """Right-pad every id list in ``arr`` with ``pad``; return ``(padded, lengths)``."""
        lengths = [len(e) for e in arr]
        max_len = max(lengths)
        return [e + [pad] * (max_len - len(e)) for e in arr], lengths

    @staticmethod
    def _to_sentence(words):
        """Join ``words`` with spaces, stopping before the first '_eos'."""
        if '_eos' in words:
            words = words[:words.index('_eos')]
        return ' '.join(words)

    @staticmethod
    def _nucleus_filter(probs, top_p):
        """Zero every entry outside the top-p nucleus along the last dimension.

        The token whose cumulative mass first exceeds ``top_p`` is kept (as in
        the HuggingFace implementation), so the most likely token always is.
        """
        sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
        remove_sorted = torch.cumsum(sorted_probs, dim=-1) > top_p
        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
        remove_sorted[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
        return probs.masked_fill(remove, 0.0)

    def prep_batch(self, rows, wh=64, wk=64):
        """Turn raw text rows into padded, batch-major LongTensors.

        Each row is a sequence of whitespace-tokenised strings. ``row[0]`` is
        truncated to its last ``wh`` tokens and forms the encoder input, and
        ``row[1]`` is the target. With ``use_knowledge``, the first ``wk``
        tokens of ``row[2]`` plus '_eos' are appended to the input.
        Out-of-vocabulary words map to '_unk'.

        Returns:
            ``(input_seq, input_lens, target_seq, target_lens)``. The sequences
            are ``(batch, max_len)`` and the lengths are Python lists.
        """
        unk = self.w2i['_unk']
        pad = self.w2i['_pad']
        device = self._device()
        rows = [[e.split() for e in row] for row in rows]
        if self.use_knowledge:
            inputs = [row[0][-wh:] + row[2][:wk] + ['_eos'] for row in rows]
        else:
            inputs = [row[0][-wh:] for row in rows]
        input_ids = [[self.w2i.get(w, unk) for w in inp] for inp in inputs]
        input_seq, input_lens = self._pad_ids(input_ids, pad)
        input_seq = torch.tensor(input_seq, dtype=torch.long, device=device)
        target_ids = [[self.w2i.get(w, unk) for w in row[1]] for row in rows]
        target_seq, target_lens = self._pad_ids(target_ids, pad)
        target_seq = torch.tensor(target_seq, dtype=torch.long, device=device)
        return input_seq, input_lens, target_seq, target_lens

    def forward(self, input_seq, input_lens, target_seq, target_lens):
        """Return the wrapped transformer's output for ``(input_seq, target_seq)``."""
        return self.transformer(input_seq, target_seq)

    def _loss(self, proba, target_seq):
        """Cross-entropy of ``proba`` (scores over the last dim) against ``target_seq``."""
        return self.criterion(proba.reshape(-1, proba.size(-1)), target_seq.reshape(-1))

    def train(self, input_seq=True, input_lens=None, target_seq=None, target_lens=None):
        """Run one optimisation step on a ``prep_batch`` batch and return the loss.

        This name shadows ``nn.Module.train(mode)``, which ``nn.Module.eval()``
        calls as ``self.train(False)``. A lone bool (or no argument) is
        therefore forwarded to ``nn.Module.train`` so ``.train()`` and
        ``.eval()`` keep working.
        """
        if isinstance(input_seq, bool):
            return super(Transformer, self).train(input_seq)
        self.optim.zero_grad()
        proba = self.forward(input_seq, input_lens, target_seq, target_lens)
        loss = self._loss(proba, target_seq)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), self.args.clip)
        self.optim.step()
        return loss.item()

    def eval_ppl(self, input_seq, input_lens, target_seq, target_lens):
        """Return the loss on a batch without updating parameters."""
        proba = self.forward(input_seq, input_lens, target_seq, target_lens)
        return self._loss(proba, target_seq).item()

    def decode(self, input_seq, input_lens, top_p=0, max_len=100):
        """Autoregressively generate one response per batch row.

        Args:
            input_seq: ``(batch, src_len)`` ids from ``prep_batch``.
            input_lens: unused; kept for API parity with Seq2Seq.
            top_p: nucleus mass for sampling; 0 selects argmax decoding.
            max_len: maximum number of generated tokens.

        Returns:
            List of strings, each cut at the first '_eos'. A row that never
            emits '_eos' keeps all of its generated words.
        """
        batch_size = input_seq.size(0)
        device = input_seq.device
        unk = self.w2i['_unk']
        # Generated ids are fed back directly (equivalent to the word round-trip
        # through i2w/w2i when the vocabulary has no duplicate words).
        prefix = torch.full((batch_size, 1), self.w2i.get('_go', unk), dtype=torch.long, device=device)
        pad_col = torch.full((batch_size, 1), self.w2i['_pad'], dtype=torch.long, device=device)
        generated = [[] for _ in range(batch_size)]
        eos_seen = [False] * batch_size
        with torch.no_grad():
            enc_output = self.transformer.enc(input_seq)
            for _ in range(max_len):
                # The trailing pad is a placeholder for the position being
                # predicted; the output at index -1 is its next-token distribution.
                target_seq = torch.cat((prefix, pad_col), dim=1)
                logits = self.transformer(input_seq, target_seq, enc_output=enc_output)
                proba = F.softmax(logits, dim=-1)[:, -1]
                if top_p == 0:
                    topi = proba.topk(1)[1]
                else:
                    topi = torch.multinomial(self._nucleus_filter(proba, top_p), 1)
                topi = topi.view(-1)
                prefix = torch.cat((prefix, topi.view(-1, 1)), dim=1)
                for b, idx in enumerate(topi.tolist()):
                    word = self.i2w[idx]
                    generated[b].append(word)
                    if word == '_eos':
                        eos_seen[b] = True
                if all(eos_seen):
                    break
        return [self._to_sentence(words) for words in generated]

    def save(self, name):
        """Pickle the whole model to ``name``."""
        torch.save(self, name)

    def load(self, name):
        """Copy parameters from a checkpoint written by ``save``.

        The file holds a whole pickled model, so it is loaded with
        ``weights_only=False``. PyTorch >= 2.6 defaults to ``True`` and would
        refuse it. SECURITY: unpickling can run arbitrary code, so only load
        trusted files.
        """
        import inspect

        load_kwargs = {'map_location': self._device()}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        self.load_state_dict(torch.load(name, **load_kwargs).state_dict())