-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsample.py
More file actions
59 lines (43 loc) · 1.57 KB
/
sample.py
File metadata and controls
59 lines (43 loc) · 1.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
# Run on the GPU when one is available; every tensor and hidden state below
# is explicitly moved to this device before being fed to the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def sen_to_index(sentence, mapping):
    """
    Translate a space-separated sentence into a list of indices.

    sentence: string whose words are separated by single spaces.
    mapping: dict from word -> integer index; a word missing from the
        mapping raises KeyError.
    Returns the list of indices, one per word, in sentence order.
    """
    return [mapping[word] for word in sentence.split(' ')]
def eva(model, datagen, warm_up, temperature = 0.2, pre_len = 30):
    """
    Generate text from the trained model, seeded by a warm-up string.

    model: trained model for text generation; called as
        model(token, hidden, cell) and must provide hidden_cell_init(batch).
    datagen: dataset object created using the Dataset_gen class; must expose
        `xdi` (word -> index) and `idx` (index -> word) mappings.
    warm_up: text used for warming up the sampling process.
    temperature: softmax temperature used to induce randomness in sampling;
        lower values are greedier, higher values are more random.
    pre_len: number of tokens to predict after the warm-up string.

    Side effect: prints the warm-up text followed by the generated tokens.
    """
    # Bug fix: the original called model.train(), which keeps dropout /
    # batch-norm layers in training mode during generation. Switch to
    # inference mode and disable gradient tracking.
    model.eval()
    inp_sentence = sen_to_index(warm_up.lower(), datagen.xdi)
    pre_string = warm_up
    hidden, cell = model.hidden_cell_init(1)
    hidden = hidden.to(device)
    cell = cell.to(device)
    with torch.no_grad():
        # Feed every warm-up token except the last to build up the hidden state.
        for j in range(len(inp_sentence) - 1):
            token = torch.LongTensor([[inp_sentence[j]]]).to(device)
            _, (hidden, cell) = model(token, hidden, cell)
        inp = inp_sentence[-1]
        for _ in range(pre_len):
            token = torch.LongTensor([[inp]]).to(device)
            output, (hidden, cell) = model(token, hidden, cell)
            # Bug fix: the original divided the logits by temperature and then
            # took argmax — a no-op, since argmax is invariant under positive
            # scaling, so temperature had no effect. Sample from the
            # temperature-scaled softmax instead, as the docstring intends.
            probs = F.softmax(output.div(temperature), dim=1)
            index = torch.multinomial(probs, 1).item()
            pre_string += ' ' + datagen.idx[index]
            inp = index
    print(pre_string)