-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathdataset.py
More file actions
129 lines (102 loc) · 5.25 KB
/
dataset.py
File metadata and controls
129 lines (102 loc) · 5.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import torch
import pickle
import pandas as pd
import numpy as np
from typing import List,Union,Optional,Dict
class SlotTokenizer:
    def __init__(self, data_list: List[str], embedding_table: Optional[pd.DataFrame] = None):
        '''
        Tokenizer for slot-based phoneme/grapheme representations.

        Each position ("slot") of the fixed-length input strings gets its own
        token inventory, built from the characters observed at that position in
        [data_list]. The placeholder character '_' marks an empty slot and is
        never assigned a token.

        Args:
            data_list (List[str]) : word corpus from which to extract the token set.
                All entries are assumed to have the same length as the first entry.
            embedding_table (Optional[pd.DataFrame]) : maps each token (row index)
                to a prespecified vector embedding. If omitted, one-hot embeddings
                are used instead.
        '''
        ### Collect, per slot, the characters that actually occur there and
        ### assign each a stable integer id in order of first appearance.
        num_slots = len(data_list[0])
        slots = {slot: {} for slot in range(num_slots)}
        for word in data_list:
            for idx, char in enumerate(word):
                if char == '_':
                    continue
                if char not in slots[idx]:
                    slots[idx][char] = len(slots[idx])
        ### Keep only slots in which at least one real character occurs.
        self.slots = {slot: slots[slot] for slot in slots if len(slots[slot])}
        self.embedding_table = embedding_table
        if embedding_table is not None:
            ### One table-row worth of dimensions per slot. NOTE(review): this
            ### counts *all* slots, including always-empty ones, so trailing
            ### dimensions may stay zero -- kept for backward compatibility.
            self.embedding_size = num_slots * len(embedding_table.columns)
        else:
            ### One-hot: one dimension per (slot, token) pair.
            self.embedding_size = sum(len(tokens) for tokens in slots.values())

    def __call__(self, word: str) -> torch.Tensor:
        '''
        Tokenize [word] and map it to a vector.

        Args:
            word (str) : string to tokenize; '_' marks an empty slot.

        Returns:
            torch.Tensor : 1-D vector embedding of length [self.embedding_size].
        '''
        embedding = torch.zeros((self.embedding_size))
        marker = 0
        ### Iterate over characters in [word]
        for idx, char in enumerate(word):
            ### Slots that were always empty in the corpus contribute nothing.
            if idx not in self.slots:
                continue
            if self.embedding_table is not None:
                ### Copy the table row for [char] into this slot's segment.
                ### NOTE(review): a '_' reaching this branch raises KeyError --
                ### assumes the table covers every character seen here.
                width = len(self.embedding_table.columns)
                embedding[marker:marker + width] = torch.FloatTensor(
                    self.embedding_table.loc[char].to_numpy()
                )
                marker += width
            else:
                ### One-hot: set the bit for [char]; '_' leaves the slot zeroed.
                if char != '_':
                    embedding[marker + self.slots[idx][char]] = 1
                marker += len(self.slots[idx])
        return embedding
class Monosyllabic_Dataset(torch.utils.data.Dataset):
    def __init__(self, path_to_words, path_to_phon_mapping, path_to_sem, sample=True):
        '''
        Generic dataset for slot-based representations of monosyllabic words.

        Args:
            path_to_words (str) : location of .csv file with orthography ('ort'),
                phonology ('pho'), and word-frequency ('wf') columns.
            path_to_phon_mapping (str) : location of tab-separated file mapping
                each phoneme (first column, no header) to its feature vector.
            path_to_sem (str) : location of file (.npy or .npz) containing
                semantic embeddings under key 'data', aligned row-for-row with
                [path_to_words].
            sample (bool) : if True, integer indexing draws a word at random
                according to scaled frequency instead.
        '''
        super(Monosyllabic_Dataset, self).__init__()
        ### reset_index so positional accesses (.iloc, tensor rows) agree with
        ### the labels returned by index queries; drop_duplicates would
        ### otherwise leave gaps in the integer index.
        data = pd.read_csv(path_to_words).drop_duplicates().reset_index(drop=True)
        ### Parse orthography; create grapheme tokenizer
        self.orthography = data['ort']
        self.orthography_tokenizer = SlotTokenizer(self.orthography)
        ### Parse phonology; create phoneme tokenizer
        self.phonology = data['pho']
        phon_mapping = pd.read_csv(path_to_phon_mapping, sep="\t", header=None).set_index(0)
        self.phonology_tokenizer = SlotTokenizer(self.phonology, phon_mapping)
        ### Parse semantics. NOTE(review): keeps only columns containing at
        ### least one zero -- presumably drops degenerate/constant dimensions;
        ### confirm against how the embeddings were generated.
        semantics = torch.FloatTensor(np.load(path_to_sem)['data'])
        self.semantics = semantics[:, (semantics == 0).any(dim=0)]
        ### Parse and scale word frequencies into a sampling distribution.
        ### TODO: allow user to adjust frequency scaling
        self.frequencies = np.clip(np.sqrt(data['wf'])/(30000**.5), .05, 1)
        self.frequencies = self.frequencies/np.sum(self.frequencies)
        self.sample = sample

    def __len__(self) -> int:
        return len(self.orthography)

    def __getitem__(self, idx: Union[int, str]) -> Dict[str, torch.Tensor]:
        '''
        Fetch the orthography/phonology/semantics vectors for one word.

        Args:
            idx (Union[int,str]) : positional index, or a word (without '_'
                padding) to look up explicitly.

        Returns:
            Dict[str,torch.Tensor] with keys 'orthography', 'phonology',
            'semantics'.
        '''
        ### An explicit word lookup always wins; frequency sampling previously
        ### discarded it silently.
        explicit = isinstance(idx, str)
        if explicit:
            idx = self.orthography.index[self.orthography.apply(lambda x: x.replace('_', '')) == idx][0]
        ### If [self.sample], sample from the word corpus by scaled frequency.
        if self.sample and not explicit:
            idx = np.random.choice(np.arange(self.__len__()), p=self.frequencies)
        ### Get orthography, phonology, and semantics vectors
        orthography = self.orthography_tokenizer(self.orthography.iloc[idx])
        phonology = self.phonology_tokenizer(self.phonology.iloc[idx])
        semantics = self.semantics[idx]
        return {'orthography': orthography, 'phonology': phonology, 'semantics': semantics}