-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathonehot.py
More file actions
34 lines (26 loc) · 1 KB
/
onehot.py
File metadata and controls
34 lines (26 loc) · 1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Onehot(object):
    """One-hot encoder built from the vocabulary of a collection of tasks.

    The vocabulary is every token found in the stories and answers of the
    train, validation and test splits of each task. Tokens are assigned
    indices in sorted order so the mapping is identical from one run to
    the next (plain ``set`` iteration order for strings changes between
    interpreter runs because of hash randomization).

    Attributes set by ``__init__``:
        words: set of all distinct tokens.
        word_to_index: dict mapping token -> integer index.
        index_to_word: dict mapping integer index -> token.
        num_words: size of the vocabulary (length of every encoding).
    """

    def __init__(self, tasks_data):
        """Collect the vocabulary and build both lookup tables.

        Args:
            tasks_data: iterable of objects exposing ``train_data``,
                ``valid_data`` and ``test_data``. Each split is an iterable
                of ``(story, answers)`` pairs, where ``story`` is an
                iterable of lines (each an iterable of tokens) and
                ``answers`` is an iterable of tokens. Tokens must be
                hashable and mutually orderable (e.g. all strings).
        """
        self.words = set()
        self.word_to_index = {}
        self.index_to_word = {}

        for td in tasks_data:
            for data in (td.train_data, td.valid_data, td.test_data):
                for story, answers in data:
                    for line in story:
                        self.words.update(line)
                    self.words.update(answers)

        # Sort before numbering: enumerating the raw set would give
        # indices that vary between runs, making encodings (and anything
        # trained on them) irreproducible.
        for index, word in enumerate(sorted(self.words)):
            self.word_to_index[word] = index
            self.index_to_word[index] = word

        self.num_words = len(self.words)

    def get_encoding(self, word):
        """Return the one-hot list for ``word``.

        The result has ``num_words`` entries, all 0 except a 1 at the
        word's index.

        Raises:
            KeyError: if ``word`` is not in the vocabulary.
        """
        encoding = [0] * self.num_words
        encoding[self.word_to_index[word]] = 1
        return encoding

    def get_word(self, encoding):
        """Return the word whose index holds the largest value in ``encoding``.

        ``encoding`` need not be strictly one-hot (e.g. a list of scores);
        on ties the lowest index wins.

        Raises:
            ValueError: if ``encoding`` is empty.
            KeyError: if the winning index is outside the vocabulary.
        """
        max_value = max(encoding)
        index = encoding.index(max_value)
        return self.index_to_word[index]