-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathword_dataset.py
More file actions
65 lines (51 loc) · 2 KB
/
word_dataset.py
File metadata and controls
65 lines (51 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import codecs
import nltk
from unidecode import unidecode
import torch
from word_embedding import WordEmbedding
from torch.utils.data import DataLoader, Dataset
class WordDataset(Dataset):
    """Dataset of fixed-length word chunks drawn from a tokenized text corpus.

    Each chunk is a ``(chunk_size, embedding_dim)`` tensor of word vectors
    looked up in the supplied embedding.
    """

    def __init__(self, text_fn, embedding, chunk_size=12):
        """
        Args:
            text_fn: path to a UTF-8 text file to tokenize.
            embedding: word-embedding object supporting ``word in embedding``,
                ``embedding[word]`` (returning a numpy vector) and ``.dim``.
            chunk_size: number of consecutive words per chunk.
        """
        self.embedding = embedding
        self.chunk_size = chunk_size
        # Read and tokenize the corpus. The context manager guarantees the
        # file handle is closed (the original leaked it).
        with codecs.open(text_fn, 'r', 'utf-8') as fp:
            words = nltk.word_tokenize(fp.read())
        words = map(unidecode, words)
        # Normalize tokens: strip commas, lowercase.
        words = [word.replace(',', '') for word in words]
        words = [word.lower() for word in words]
        # Split hyphenated tokens into their parts. The original removed
        # items from `words` while iterating over it, which silently skips
        # elements; build a new list instead.
        split_words = []
        for word in words:
            if '-' in word:
                split_words.extend(word.split('-'))
            else:
                split_words.append(word)
        # Drop empty strings left over by normalization/splitting.
        self.words = [word for word in split_words if word]

    def __len__(self):
        """Total number of tokens in the corpus."""
        return len(self.words)

    def dim(self):
        """Dimensionality of the word-embedding vectors."""
        return self.embedding.dim

    def get_chunk(self):
        """Return one random chunk as a ``(chunk_size, dim)`` tensor.

        Resamples until every word in the window is in the embedding
        vocabulary. NOTE(review): loops forever if no such window exists.
        """
        while True:
            # `high` is exclusive, so len(self) - chunk_size + 1 makes the
            # final valid start index reachable (the original's `- 1`
            # excluded the last two windows).
            sta_ind = np.random.randint(0, len(self) - self.chunk_size + 1)
            chunk_words = self.words[sta_ind:sta_ind + self.chunk_size]
            # Bug fix: the original tested membership against the *global*
            # name `embedding`, which only resolves when the module is run
            # as a script; use the instance's embedding.
            if all(word in self.embedding for word in chunk_words):
                break
        vec_chunk = np.stack([self.embedding[word] for word in chunk_words])
        return torch.from_numpy(vec_chunk)

    def get_chunks(self, n_chunks):
        """Stack ``n_chunks`` random chunks into a ``(n_chunks, chunk_size, dim)`` tensor."""
        return torch.stack([self.get_chunk() for _ in range(n_chunks)])
if __name__ == '__main__':
    # Smoke test: build a dataset over the sample corpus and print the
    # shapes of a single chunk and a batch of chunks.
    embedding_fn = '/Users/bkeating/nltk_data/embeddings/glove/glove.6B.100d.txt'
    embedding = WordEmbedding(embedding_fn)
    dataset = WordDataset('data/keywell_corpus.txt', embedding)
    print(dataset.get_chunk().size())
    print(dataset.get_chunks(20).size())