-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoptions.py
More file actions
139 lines (120 loc) · 3.26 KB
/
options.py
File metadata and controls
139 lines (120 loc) · 3.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gin
import inspect
def parse_common_args(parser):
    """Register the command-line arguments shared by every entry point.

    Adds the positional ``config_files`` argument (one or more values) and
    the options ``-j/--workers``, ``--seed``, ``--rank`` and ``--log-dir``.

    Args:
        parser: Object exposing ``add_argument`` (for example an
            ``argparse.ArgumentParser``).

    Returns:
        The same ``parser`` object, after the arguments were added.
    """
    # Each entry is (flags, keyword settings). Entries are registered in the
    # order listed so the generated help text keeps its layout.
    shared_specs = (
        (
            ("config_files",),
            {
                "metavar": "N",
                "type": str,
                "nargs": "+",
                "help": "path to gin config files",
            },
        ),
        # device and system config
        (
            ("-j", "--workers"),
            {
                "default": 4,
                "type": int,
                "metavar": "N",
                "help": "number of data loading workers (default: 4)",
            },
        ),
        (
            ("--seed",),
            {
                "default": None,
                "type": int,
                "help": "seed for initializing training",
            },
        ),
        (
            ("--rank",),
            {
                "default": 0,
                "type": int,
                "help": "node rank for distributed training",
            },
        ),
        # IO config
        (
            ("--log-dir",),
            {
                "type": str,
                "default": "./out",
                "help": "directory to save logs (default: ./out)",
            },
        ),
    )
    for flags, settings in shared_specs:
        parser.add_argument(*flags, **settings)
    return parser
def parse_train_args(parser):
    """Register the arguments used by the training entry point.

    First delegates to ``parse_common_args`` and then adds the
    training-only options: ``--save-dir``, ``-p/--print-freq``,
    ``--eval-freq``, ``--save-freq``, ``--world-size``, ``--dist-url`` and
    ``--dist-backend``.

    Args:
        parser: Object exposing ``add_argument`` (for example an
            ``argparse.ArgumentParser``).

    Returns:
        The parser returned by ``parse_common_args``, with the training
        options added.
    """
    parser = parse_common_args(parser)

    # (flags, keyword settings) pairs, registered in the order listed.
    train_specs = (
        # IO config
        (
            ("--save-dir",),
            {
                "type": str,
                "default": "./checkpoints",
                "help": "directory to save model checkpoint (default: ./checkpoints)",
            },
        ),
        (
            ("-p", "--print-freq"),
            {
                "default": 1000,
                "type": int,
                "metavar": "N",
                "help": "print frequency (default: 1000 iterations)",
            },
        ),
        (
            ("--eval-freq",),
            {
                "default": 1,
                "type": int,
                "metavar": "N",
                "help": "model evaluation frequency (default: 1 epochs)",
            },
        ),
        (
            ("--save-freq",),
            {
                "default": 10,
                "type": int,
                "metavar": "N",
                "help": "save checkpoints frequency (default: 10 epochs)",
            },
        ),
        # device and system config
        (
            ("--world-size",),
            {
                "default": 1,
                "type": int,
                "help": "number of nodes for distributed training",
            },
        ),
        (
            ("--dist-url",),
            {
                "default": "tcp://127.0.0.1:40404",
                "type": str,
                "help": "url used to set up distributed training",
            },
        ),
        (
            ("--dist-backend",),
            {
                "default": "nccl",
                "type": str,
                "help": "distributed backend",
            },
        ),
    )
    for flags, settings in train_specs:
        parser.add_argument(*flags, **settings)
    return parser
def parse_test_args(parser):
    """Register the arguments used by the test/evaluation entry point.

    First delegates to ``parse_common_args`` and then adds ``--gpu`` (int,
    default 0) plus the boolean switches ``--plot`` and ``--save-features``.

    Args:
        parser: Object exposing ``add_argument`` (for example an
            ``argparse.ArgumentParser``).

    Returns:
        The parser returned by ``parse_common_args``, with the test
        options added.
    """
    parser = parse_common_args(parser)
    register = parser.add_argument

    register("--gpu", type=int, default=0, help="GPU id")

    # Both remaining options are plain on/off switches.
    switches = (
        ("--plot", "plot curves"),
        ("--save-features", "save extracted features"),
    )
    for flag, description in switches:
        register(flag, action="store_true", help=description)
    return parser
@gin.configurable("train", denylist=["args"])
def inject_to_train_args(
    args,
    continue_training=False,
    max_steps=gin.REQUIRED,
    resume=gin.REQUIRED,
    start_epoch=gin.REQUIRED,
    epochs=gin.REQUIRED,
    batch_size=gin.REQUIRED,
    val_batch_size=gin.REQUIRED,
    lr_dict=gin.REQUIRED,
    gamma=gin.REQUIRED,
    max_norm=gin.REQUIRED,
):
    """Copy the gin-configured training settings onto ``args``.

    The function is registered with gin under the name ``"train"``; ``args``
    is on the denylist, so it is passed by the caller rather than bound from
    a gin config. Every other parameter is stored on ``args`` as an
    attribute with the same name, overwriting any existing attribute.

    Args:
        args: Mutable object accepting ``setattr`` (presumably the parsed
            ``argparse`` namespace -- confirm against the caller).
        continue_training: Stored as ``args.continue_training``; defaults to
            ``False``.
        max_steps, resume, start_epoch, epochs, batch_size, val_batch_size,
        lr_dict, gamma, max_norm: Marked ``gin.REQUIRED``. Each value is
            stored unchanged as the attribute of the same name.

    Returns:
        The same ``args`` object, mutated in place.
    """
    # Snapshot the parameters before any other local exists, so the mapping
    # holds exactly the call's arguments in signature order. This replaces
    # the former ``eval(name)`` lookup driven by ``inspect.signature``.
    bound = dict(locals())
    for name, value in bound.items():
        if name != "args":
            setattr(args, name, value)
    return args
@gin.configurable("test", denylist=["args"])
def inject_to_test_args(
    args,
    max_steps=gin.REQUIRED,
    resume=gin.REQUIRED,
    batch_size=gin.REQUIRED,
):
    """Copy the gin-configured test settings onto ``args``.

    The function is registered with gin under the name ``"test"``; ``args``
    is on the denylist, so it is passed by the caller rather than bound from
    a gin config. Every other parameter is stored on ``args`` as an
    attribute with the same name, overwriting any existing attribute.

    Args:
        args: Mutable object accepting ``setattr`` (presumably the parsed
            ``argparse`` namespace -- confirm against the caller).
        max_steps, resume, batch_size: Marked ``gin.REQUIRED``. Each value
            is stored unchanged as the attribute of the same name.

    Returns:
        The same ``args`` object, mutated in place.
    """
    # Snapshot the parameters before any other local exists, so the mapping
    # holds exactly the call's arguments in signature order. This replaces
    # the former ``eval(name)`` lookup driven by ``inspect.signature``.
    bound = dict(locals())
    for name, value in bound.items():
        if name != "args":
            setattr(args, name, value)
    return args