-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreprocess.py
More file actions
104 lines (86 loc) · 3.91 KB
/
preprocess.py
File metadata and controls
104 lines (86 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
import pandas as pd
import os.path
from utils import *
class BabiGen:
    """Parse a bAbI task file, embed each word with a cached random vector,
    and batch the stories into zero-padded arrays.

    Reads ``babi/task_{task_num}_back.txt`` (bAbI format: numbered story
    sentences, plus tab-separated question/answer/label lines).
    """

    def __init__(self, task_num, batch_size, embeddings_size):
        self.task_num = task_num
        self.batch_size = batch_size
        self.embeddings_size = embeddings_size
        self.data = []           # list of (story, question, answer, label) tuples
        self.embed_weights = {}  # word -> fixed random embedding vector
        self.load_data()
        # Bug fix: preprocess_data() builds and RETURNS the batched/padded
        # array -- the original discarded the return value, leaving
        # self.data unprocessed.
        self.data = self.preprocess_data()
        # self.store_data()

    def load_data(self):
        """Parse the raw task file into ``self.data``.

        For each question line, appends a tuple
        ``(story_flat, question, answer, label)`` where every element is
        built from per-word embeddings and ``story_flat`` is the story so
        far, flattened to shape (num_words, embeddings_size).
        """
        story = []  # embedded sentences of the current story (renamed: 'input' shadowed a builtin)
        with open(f'babi/task_{self.task_num}_back.txt') as file:
            task = file.read()
        # Make '.' and '?' stand-alone tokens before splitting on spaces.
        task = task.replace('.', ' .')
        task = task.replace('?', ' ?')
        sentences = task.split('\n')
        sentences = np.array([np.array(s.split(' ')) for s in sentences])
        # Skip the trailing empty entry produced by the file's final newline.
        for s in sentences[:-1]:
            if s[0] == '1':
                # Sentence number 1 marks the start of a new story.
                story = []
            # Only answer lines contain a '\t' (question \t answer \t label).
            # np.char is the public alias of the deprecated np.core.defchararray.
            if any(np.char.find(s, '\t') != -1):
                s = np.concatenate((s[:-1], s[-1].split('\t')))
                s = [self.embed(w) for w in s]
                # TODO: modify to support multiple label indices
                question, answer, label = s[0:-2], s[-2], s[-1]
                story_flat = np.array(story).reshape(-1, self.embeddings_size)
                self.data.append((story_flat, question, answer, label))
            else:
                # TODO: modify to store sentence number
                s = [self.embed(w) for w in s[1:]]  # s[0] is the sentence number
                story.append(s)

    def preprocess_data(self):
        """Group examples into batches and zero-pad stories to equal length.

        Returns an object array of shape (num_batches, 4, batch_size); axis 1
        holds (story, question, answer, label). Trailing examples that do not
        fill a whole batch are dropped.
        """
        data = np.array(self.data)
        num_batches = len(data) // self.batch_size
        # Drop the tail that does not fill a complete batch.
        data = data[:num_batches * self.batch_size]
        data = data.reshape(num_batches, self.batch_size, -1)
        data = np.swapaxes(data, 1, 2)
        for i in range(len(data)):
            # Pad every story in batch i to the longest story in that batch.
            max_length = max(data[i, 0, j].shape[0] for j in range(len(data[i, 0])))
            for j in range(len(data[i, 0])):
                d = data[i, 0, j]
                if max_length != d.shape[0]:
                    data[i, 0, j] = np.pad(d, ((0, max_length - d.shape[0]), (0, 0)), 'constant')
        return data

    def store_data(self):
        """Persist the parsed data plus the '.' (end-of-sentence) embedding."""
        file_name = f'babi/parsed/{self.task_num}_{self.batch_size}'
        np.savez(file_name, self.data, self.embed('.'))

    def embed(self, word):
        """Return the embedding vector for *word*, creating and caching a
        random one on first use."""
        if word not in self.embed_weights:
            self.embed_weights[word] = np.random.randn(self.embeddings_size)
        return self.embed_weights[word]

    def get_data(self):
        """Return the parsed (and, after __init__, preprocessed) data."""
        return self.data
class BabiTask:
    """Batched reader over a preprocessed bAbI dataset stored as an .npz file.

    Archive layout: arr_0..arr_3 are the training story/question/answer/
    supporting-fact arrays, arr_4..arr_7 the corresponding test split,
    arr_8 the vocabulary size and arr_9 the end-of-sentence vector.
    """

    def __init__(self, batch_size, file_name):
        self.epoch = 0  # completed passes over the training data
        self.i = -1     # current batch index (-1 = before the first batch)
        self.batch_size = batch_size
        file_name = f'babi/generated_data_two_fact_sup_{file_name}.npz'
        # Resource-leak fix: np.load on an .npz keeps the file handle open
        # inside the returned NpzFile; all arrays are read eagerly here, so a
        # context manager can safely close it.
        with np.load(file_name) as file:
            self.x, self.xq, self.y, self.sup = file['arr_0'], file['arr_1'], file['arr_2'], file['arr_3']
            self.tx, self.txq, self.ty, self.tsup = file['arr_4'], file['arr_5'], file['arr_6'], file['arr_7']
            self.vocab_size = file['arr_8']
            self.eos_vector = file['arr_9']

    def get_lengths(self):
        """Return (story length, question length, vocabulary size)."""
        return self.x.shape[1], self.xq.shape[1], self.vocab_size

    def next_batch(self):
        """Advance to the next training batch and return its four slices.

        Wraps back to batch 0 (incrementing ``epoch``) when advancing would
        leave less than one full batch; a trailing partial batch is never
        served.
        """
        if (self.i + 2) * self.batch_size > len(self.x):
            self.i = 0
            self.epoch += 1
        else:
            self.i += 1
        lo = self.i * self.batch_size
        hi = lo + self.batch_size
        return self.x[lo:hi], self.xq[lo:hi], self.y[lo:hi], self.sup[lo:hi]

    def dev_data(self, num_cases=100):
        """Return the first *num_cases* examples of the test split."""
        return self.tx[:num_cases], self.txq[:num_cases], self.ty[:num_cases], self.tsup[:num_cases]