forked from spyysalo/bert-span-classifier
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: create_tfrecords.py
More file actions
120 lines (97 loc) · 3.37 KB
/
create_tfrecords.py
File metadata and controls
120 lines (97 loc) · 3.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
#!/usr/bin/env python3
import sys
import tensorflow as tf
import bert_tokenization as tokenization
from collections import OrderedDict
from argparse import ArgumentParser
from common import load_labels, tsv_generator
from config import DEFAULT_SEQ_LEN
def argparser():
    """Build the command-line parser for the TFRecord creation script."""
    parser = ArgumentParser()
    parser.add_argument('--input_file', required=True,
                        help='Input data in TSV format')
    parser.add_argument('--output_file', required=True,
                        help='Output TF example file')
    parser.add_argument('--labels', required=True,
                        help='File containing list of labels')
    parser.add_argument('--vocab_file', required=True,
                        help='Vocabulary file that BERT model was trained on')
    parser.add_argument('--max_seq_length', type=int, default=DEFAULT_SEQ_LEN,
                        help='Maximum input sequence length in WordPieces')
    parser.add_argument('--do_lower_case', default=False, action='store_true',
                        help='Lower case input text (for uncased models)')
    parser.add_argument('--replace_span', default=None,
                        help='Replace span text with given special token')
    parser.add_argument('--label_field', type=int, default=-4,
                        help='Index of label in TSV data (1-based)')
    parser.add_argument('--text_fields', type=int, default=-3,
                        help='Index of first text field in TSV data (1-based)')
    parser.add_argument('--max_examples', type=int, default=None,
                        help='Maximum number of examples to generate')
    return parser
class Example(object):
    """A single classification example: token ids, segment ids and a label.

    Attributes:
        token_ids: sequence of WordPiece token ids for the input.
        segment_ids: sequence of segment ids aligned with token_ids.
        label: integer label index for the example.
    """

    def __init__(self, x, y):
        # x must be a (token_ids, segment_ids) pair. Validate with an
        # explicit exception rather than `assert`, which is silently
        # stripped when Python runs with optimizations (-O).
        if len(x) != 2:
            raise ValueError(
                'expected (token_ids, segment_ids) pair, got {} items'.format(len(x)))
        self.token_ids = x[0]
        self.segment_ids = x[1]
        self.label = y

    def to_tf_example(self):
        """Convert to a tf.train.Example with Input-Token, Input-Segment
        and label int64 features (names match the training-side inputs)."""
        features = OrderedDict()
        features['Input-Token'] = create_int_feature(self.token_ids)
        features['Input-Segment'] = create_int_feature(self.segment_ids)
        features['label'] = create_int_feature([self.label])
        return tf.train.Example(features=tf.train.Features(feature=features))

    def __str__(self):
        # Human-readable dump, one space-separated field per line.
        return 'token_ids: {}\nsegment_ids: {}\nlabel: {}'.format(
            ' '.join(str(t) for t in self.token_ids),
            ' '.join(str(s) for s in self.segment_ids),
            self.label
        )
def create_int_feature(values):
    """Wrap an iterable of ints in a tf.train.Feature holding an Int64List."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
def write_examples(examples, output_file):
    """Serialize Examples into a TFRecord file and report the count to stderr."""
    written = 0
    with tf.io.TFRecordWriter(output_file) as writer:
        # enumerate from 1 so `written` ends up equal to the number of
        # records emitted (stays 0 for an empty input).
        for written, example in enumerate(examples, start=1):
            writer.write(example.to_tf_example().SerializeToString())
    print('wrote {} examples to {}'.format(written, output_file), file=sys.stderr)
def main(argv):
    """Tokenize TSV input and write the resulting examples as TFRecords.

    Returns 0 on success (used as the process exit status).
    """
    args = argparser().parse_args(argv[1:])
    tokenizer = tokenization.FullTokenizer(
        vocab_file=args.vocab_file,
        do_lower_case=args.do_lower_case
    )
    # Map each label string to its index in the label file.
    label_map = {label: index
                 for index, label in enumerate(load_labels(args.labels))}
    examples = []
    for x, y in tsv_generator(args.input_file, tokenizer, label_map, args):
        examples.append(Example(x, y))
        # Stop early once the optional example cap has been reached.
        if args.max_examples and len(examples) >= args.max_examples:
            break
    write_examples(examples, args.output_file)
    return 0
if __name__ == '__main__':
    # Propagate main()'s return value as the process exit status.
    sys.exit(main(sys.argv))