forked from apcode/tensorflow_fasttext
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprocess_input.py
More file actions
146 lines (126 loc) · 5.09 KB
/
process_input.py
File metadata and controls
146 lines (126 loc) · 5.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""Process input data into tensorflow examples, to ease training.
Input data is in one of two formats:
- facebook's format used in their fastText library.
- two text files, one with input text per line, the other a label per line.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import re
import sys
import tensorflow as tf
from collections import Counter
from six.moves import zip
from nltk.tokenize import word_tokenize
# Command-line flags. Exactly one input mode must be supplied:
# either --facebook_input, or --text_input together with --labels
# (main() exits with an error otherwise).
tf.flags.DEFINE_string("facebook_input", None,
                       "Input file in facebook train|test format")
tf.flags.DEFINE_string("text_input", None,
                       """Input text file containing one text phrase per line.
Must have --labels defined""")
tf.flags.DEFINE_string("labels", None,
                       """Input text file containing one label for
classification per line.
Must have --text_input defined.""")
# Comma-separated character n-gram sizes; main() keeps only sizes in 2..6.
tf.flags.DEFINE_string("ngrams", None,
                       "list of ngram sizes to create, e.g. --ngrams=2,3,4,5")
tf.flags.DEFINE_string("output_dir", ".",
                       "Directory to store resulting vector models and checkpoints in")
tf.flags.DEFINE_integer("num_shards", 1,
                        "Number of outputfiles to create")
# Parsed flag values; populated when tf.app.run() invokes main().
FLAGS = tf.flags.FLAGS
def CleanText(text):
    """Lowercase *text* and split it into a list of word tokens.

    Args:
      text: input string (one phrase/sentence).
    Returns:
      List of lowercase token strings from nltk's word_tokenize.
    """
    # Bug fix: the original called the misspelled `word_tokenise`, which
    # raised NameError — the module imports `word_tokenize` from nltk.
    return word_tokenize(text.lower())
def NGrams(words, ngrams):
    """Build all character n-grams of the requested sizes for each word.

    Args:
      words: iterable of strings.
      ngrams: iterable of n-gram lengths (ints).
    Returns:
      Flat list of substrings, in word order then size order; a word
      shorter than a given size contributes no n-grams of that size.
    """
    return [word[start:start + size]
            for word in words
            for size in ngrams
            for start in range(len(word) - size + 1)]
def ParseFacebookInput(inputfile, ngrams):
    """Parse a file in facebook's fastText format into example dicts.

    Each non-blank line looks like: ``__label__<N> , word word ...``.

    Args:
      inputfile: path of the input file.
      ngrams: list of char n-gram sizes to add, or None/empty to skip.
    Returns:
      List of dicts with keys "text" (list of word strings), "label"
      (zero-based int), and "ngrams" (list of strings) when requested.
    Raises:
      ValueError: if a non-blank line lacks a leading __label__ field.
    """
    examples = []
    # Bug fix: the original never closed the file handle.
    with open(inputfile) as f:
        for line in f:
            words = line.split()
            if not words:
                continue  # tolerate blank lines (original crashed on words[0])
            # Label is the first field with the __label__ prefix removed.
            match = re.match(r'__label__([0-9]+)', words[0])
            if match is None:
                # Bug fix: the original set label=None and then crashed
                # later on `None - 1`; fail loudly with context instead.
                raise ValueError("missing __label__ field in line: %r" % line)
            label = int(match.group(1))
            # Strip out the label token and the following ','.
            words = words[2:]
            example = {
                "text": words,
                "label": label - 1,  # convert 1-based labels to 0-based
            }
            if ngrams:
                example["ngrams"] = NGrams(words, ngrams)
            examples.append(example)
    return examples
def ParseTextInput(textfile, labelsfile, ngrams):
    """Parse parallel text/label files into example dicts.

    Args:
      textfile: path of a file with one text phrase per line.
      labelsfile: path of a file with one 1-based integer label per line.
        (Bug fix: the original signature misspelled this as `labelsfie`,
        so the body's `open(labelsfile)` raised NameError.)
      ngrams: list of char n-gram sizes to add, or None/empty to skip.
    Returns:
      List of dicts with keys "text" (token list from CleanText), "label"
      (zero-based int), and "ngrams" (list of strings) when requested.
    """
    examples = []
    with open(textfile) as f1, open(labelsfile) as f2:
        # zip stops at the shorter file; extra lines are ignored.
        for text, label in zip(f1, f2):
            words = CleanText(text)
            example = {
                "text": words,
                "label": int(label) - 1,  # convert 1-based labels to 0-based
            }
            if ngrams:
                # Bug fix: the original referenced an undefined `words`
                # name here; build n-grams from the cleaned tokens.
                example["ngrams"] = NGrams(words, ngrams)
            examples.append(example)
    return examples
def WriteExamples(examples, outputfile, num_shards):
    """Write examples in TFRecord format, sharded over num_shards files.

    Args:
      examples: list of feature dicts:
        {'text': [words], 'label': int, optional 'ngrams': [strings]}.
      outputfile: full pathname prefix; each shard is written to
        '<outputfile>-<shard>-of-<num_shards>' (shards numbered from 1).
      num_shards: number of output files to create.
    """
    shard = 0
    writer = None
    # Bug fix: the original used `/` which, under
    # `from __future__ import division`, yields a float; `n % float == 0`
    # then only matched n == 0, so every record went to the first shard.
    num_per_shard = len(examples) // num_shards + 1
    for n, example in enumerate(examples):
        if n % num_per_shard == 0:
            if writer is not None:
                writer.close()  # bug fix: original leaked every writer
            shard += 1
            writer = tf.python_io.TFRecordWriter(
                outputfile + '-%d-of-%d' % (shard, num_shards))
        record = tf.train.Example()
        text = [tf.compat.as_bytes(x) for x in example["text"]]
        record.features.feature["text"].bytes_list.value.extend(text)
        record.features.feature["label"].int64_list.value.append(
            example["label"])
        if "ngrams" in example:
            ngrams = [tf.compat.as_bytes(x) for x in example["ngrams"]]
            record.features.feature["ngrams"].bytes_list.value.extend(ngrams)
        writer.write(record.SerializeToString())
    if writer is not None:
        writer.close()
def WriteVocab(examples, vocabfile, labelfile):
    """Write the vocabulary and label set extracted from *examples*.

    The vocab file lists words most-common-first — NCE loss in TF
    assumes a Zipf-ordered vocabulary — and the label file lists the
    distinct labels in sorted order, one per line.

    Args:
      examples: list of dicts with "text" (word list) and "label" keys.
      vocabfile: output path for the vocabulary.
      labelfile: output path for the labels.
    """
    word_counts = Counter()
    label_set = set()
    for ex in examples:
        word_counts.update(ex["text"])
        label_set.add(ex["label"])
    with open(vocabfile, "w") as vf:
        # Most common first, one word per line.
        vf.writelines(w + '\n' for w, _ in word_counts.most_common())
    with open(labelfile, "w") as lf:
        for lab in sorted(label_set):
            lf.write(str(lab) + '\n')
def main(_):
    """Entry point: validate flags, parse input, write records and vocab.

    Requires either --facebook_input, or both --text_input and --labels;
    exits with status 1 otherwise. Output filenames are derived from the
    input filename inside --output_dir.
    """
    # Check flags.
    if not (FLAGS.facebook_input or (FLAGS.text_input and FLAGS.labels)):
        # Bug fix: the original used Python 2 `print >>sys.stderr, ...`,
        # which is a runtime TypeError under
        # `from __future__ import print_function`.
        print("Error: You must define either facebook_input or both "
              "text_input and labels", file=sys.stderr)
        sys.exit(1)
    ngrams = None
    if FLAGS.ngrams:
        # Parse the comma-separated sizes, keeping only the supported 2..6.
        ngrams = [int(g) for g in FLAGS.ngrams.split(',')]
        ngrams = [g for g in ngrams if 1 < g < 7]
    if FLAGS.facebook_input:
        inputfile = FLAGS.facebook_input
        examples = ParseFacebookInput(FLAGS.facebook_input, ngrams)
    else:
        inputfile = FLAGS.text_input
        examples = ParseTextInput(FLAGS.text_input, FLAGS.labels, ngrams)
    outputfile = os.path.join(FLAGS.output_dir, inputfile + ".tfrecords")
    WriteExamples(examples, outputfile, FLAGS.num_shards)
    vocabfile = os.path.join(FLAGS.output_dir, inputfile + ".vocab")
    labelfile = os.path.join(FLAGS.output_dir, inputfile + ".labels")
    WriteVocab(examples, vocabfile, labelfile)


if __name__ == '__main__':
    tf.app.run()