-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathconfig.py
More file actions
54 lines (47 loc) · 1.75 KB
/
config.py
File metadata and controls
54 lines (47 loc) · 1.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# -*- coding: utf-8 -*-
import os
import torch
class Config:
    """Paths and training hyper-parameters gathered in a single object.

    Instantiating the class has side effects: ``__init__`` calls
    :meth:`check_dir`, which raises if either raw-data directory is missing
    and creates the model/summary directories when they do not exist yet.
    All paths are relative, so they resolve against the current working
    directory at construction time.
    """

    def __init__(self):
        # general param
        self.RETRAIN = True
        # CUDA is switched off by hand; use torch.cuda.is_available() here
        # to auto-detect a GPU instead.
        self.USE_CUDA = False

        # define the data paths
        self.RAW_TRAIN_DATA = "./data/cards_250_7/cards_for_train"
        self.RAW_TEST_DATA = "./data/cards_250_7/cards_for_eval"

        # define the source path (directories that are created on demand)
        self.SOURCE_DIR_PATH = {
            "MODEL_DIR": "./source/models/",
            "SUMMARY_DIR": "./source/summary/",
        }

        # define the file path
        self.LABEL_TO_NAME_PATH = "./source/label_to_name_dict.pkl"
        self.NAME_TO_LABEL_PATH = "./source/name_to_label_dict.pkl"

        # check the path (fails fast before any training parameter is set)
        self.check_dir()

        # define the param of the training
        self.WIDTH = 488
        self.HEIGHT = 488
        self.CHANNEL = 3
        self.NUM_CLASS = 250
        self.BATCH_SIZE = 30
        self.NUM_EPOCHS = 500
        self.LEARNING_RATE = 0.001
        # NOTE(review): presumably a validation interval -- confirm the unit
        # against the training loop that consumes it.
        self.VALPERBATCH = 2

    def check_dir(self):
        """Validate the data directories and create the source directories.

        The evaluation data path is checked first, then the training data
        path. Every directory listed in ``SOURCE_DIR_PATH`` that does not
        exist yet is announced on stdout and created.

        :return: None
        :raises FileNotFoundError: if ``RAW_TEST_DATA`` or ``RAW_TRAIN_DATA``
            does not exist. ``FileNotFoundError`` is a subclass of
            ``Exception``, so callers catching ``Exception`` are unaffected.
        """
        # check the data path
        for data_path in (self.RAW_TEST_DATA, self.RAW_TRAIN_DATA):
            if not os.path.exists(data_path):
                raise FileNotFoundError(
                    "==> Error: Data path %s does not exist." % data_path
                )

        # check the source path
        for name, path in self.SOURCE_DIR_PATH.items():
            if not os.path.exists(path):
                print("==> Creating %s : %s" % (name, path))
            # exist_ok avoids a crash if the directory appears between the
            # existence check above and this call.
            os.makedirs(path, exist_ok=True)