-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathpreprecess.py
More file actions
36 lines (31 loc) · 1.25 KB
/
preprecess.py
File metadata and controls
36 lines (31 loc) · 1.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# -*- coding: utf-8 -*-
import os
import pickle
from config import Config
class preprocess:
    """Build and persist the label <-> integer-id lookup tables for the training set.

    The config object must provide ``RAW_TRAIN_DATA`` (directory to scan),
    ``NAME_TO_LABEL_PATH`` and ``LABEL_TO_NAME_PATH`` (pickle output paths).
    """

    def __init__(self, config):
        """Store *config*, the object supplying the input/output paths."""
        self.config = config

    def gen_label_dict(self):
        """Scan the training directory and pickle the label/id mappings.

        Every sub-directory of ``config.RAW_TRAIN_DATA`` contributes a label:
        the part of its name before the first ``"_"``. Each distinct label is
        given a contiguous integer id ``0 .. k-1``, assigned in sorted
        directory-name order, so the numbering is reproducible between runs.
        Entries that are not directories are ignored.

        Two pickles are written:

        * ``config.NAME_TO_LABEL_PATH``: ``{label_str: id}``
        * ``config.LABEL_TO_NAME_PATH``: ``{id: label_str}`` (the inverse)
        """
        train_path = self.config.RAW_TRAIN_DATA
        name_to_label_dict = dict()
        print("==> Generating name_to_label_dict.")
        # os.listdir() returns entries in arbitrary order; sort so that the
        # id assigned to each label is the same on every run/platform.
        for name in sorted(os.listdir(train_path)):
            if not os.path.isdir(os.path.join(train_path, name)):
                continue
            label = name.split("_")[0]  # label is the prefix before the first "_"
            # Several directories may share one prefix (e.g. "cat_1", "cat_2");
            # only the first occurrence gets a new id, so ids have no gaps.
            if label not in name_to_label_dict:
                name_to_label_dict[label] = len(name_to_label_dict)
        # Use context managers so the pickles are flushed and the handles closed.
        with open(self.config.NAME_TO_LABEL_PATH, "wb") as f:
            pickle.dump(name_to_label_dict, f)
        print("There are %d labels" % len(name_to_label_dict))
        print("==> Generating label_to_name_dict.")
        label_to_name_dict = {idx: label for label, idx in name_to_label_dict.items()}
        with open(self.config.LABEL_TO_NAME_PATH, "wb") as f:
            pickle.dump(label_to_name_dict, f)
if __name__ == "__main__":
    # Script entry: build the label lookup pickles only if either is missing.
    conf = Config()
    pre = preprocess(conf)
    have_both_dicts = (
        os.path.exists(conf.LABEL_TO_NAME_PATH)
        and os.path.exists(conf.NAME_TO_LABEL_PATH)
    )
    if not have_both_dicts:
        print("==> Generating label dict.")
        pre.gen_label_dict()